[opt] Cache loop-invariant global vars to local vars #6072

Merged · 16 commits · Sep 23, 2022
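
For readers skimming the diff, here is a minimal sketch (not part of the PR; the fields, shapes, and the `accumulate` kernel are assumptions made for illustration) of the pattern the new pass targets: the address `y[i]` is invariant with respect to the inner loop and uniquely accessed per outer iteration, so its repeated global loads and stores can be replaced by a local variable plus a single write-back.

```python
# Illustrative only (not from the PR): fields, shapes, and the kernel name
# are assumptions made for this sketch.
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=(16, 128))
y = ti.field(ti.f32, shape=16)

@ti.kernel
def accumulate():
    for i in range(16):        # offloaded (parallelized) outer loop
        for j in range(128):   # inner loop: the address y[i] does not depend on j
            y[i] += x[i, j]    # without the pass: one global load + store of y[i] per j

accumulate()
```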
2 changes: 2 additions & 0 deletions taichi/ir/transforms.h
@@ -47,6 +47,8 @@ bool whole_kernel_cse(IRNode *root);
bool extract_constant(IRNode *root, const CompileConfig &config);
bool unreachable_code_elimination(IRNode *root);
bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config);
bool cache_loop_invariant_global_vars(IRNode *root,
                                      const CompileConfig &config);
void full_simplify(IRNode *root,
                   const CompileConfig &config,
                   const FullSimplifyPass::Args &args);
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
@@ -28,6 +28,7 @@ struct CompileConfig {
  bool lower_access;
  bool simplify_after_lower_access;
  bool move_loop_invariant_outside_if;
  bool cache_loop_invariant_global_vars{true};
  bool demote_dense_struct_fors;
  bool advanced_optimization;
  bool constant_folding;
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
@@ -166,6 +166,8 @@ void export_lang(py::module &m) {
      .def_readwrite("lower_access", &CompileConfig::lower_access)
      .def_readwrite("move_loop_invariant_outside_if",
                     &CompileConfig::move_loop_invariant_outside_if)
      .def_readwrite("cache_loop_invariant_global_vars",
                     &CompileConfig::cache_loop_invariant_global_vars)
      .def_readwrite("default_cpu_block_dim",
                     &CompileConfig::default_cpu_block_dim)
      .def_readwrite("cpu_block_dim_adaptive",
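
Because the flag is exported to Python above and defaults to true, it can presumably be turned off for debugging or A/B timing. A minimal sketch, assuming `ti.init` forwards `CompileConfig` field names as keyword arguments the way it does for other flags (verify against your Taichi version):

```python
import taichi as ti

# Assumption: ti.init accepts CompileConfig field names as keyword arguments,
# as it does for other flags; check this against your Taichi version.
ti.init(arch=ti.cpu, cache_loop_invariant_global_vars=False)  # opt out of the pass
```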
197 changes: 197 additions & 0 deletions taichi/transforms/cache_loop_invariant_global_vars.cpp
@@ -0,0 +1,197 @@
#include "taichi/transforms/loop_invariant_detector.h"
#include "taichi/ir/analysis.h"

TLANG_NAMESPACE_BEGIN

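// Added comment (not in the merged file), summarizing the pass based on the
// code below: cache an offload-unique, loop-invariant global destination
// (GlobalPtrStmt or ExternalPtrStmt) into a local alloca, load it once before
// the loop if it is read, redirect in-loop accesses to the local, and write it
// back once after the loop if it is written.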
class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
 public:
  using LoopInvariantDetector::visit;

  enum class CacheStatus {
    None = 0,
    Read = 1,
    Write = 2,
    ReadWrite = 3,
  };

  typedef std::unordered_map<Stmt *,
                             std::pair<CacheStatus, std::vector<Stmt *>>>
      CacheMap;
  std::stack<CacheMap> cached_maps;

  DelayedIRModifier modifier;
  std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;
  std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_;

  OffloadedStmt *current_offloaded;

  explicit CacheLoopInvariantGlobalVars(const CompileConfig &config)
      : LoopInvariantDetector(config) {
  }

  void visit(OffloadedStmt *stmt) override {
    if (stmt->task_type == OffloadedTaskType::range_for ||
        stmt->task_type == OffloadedTaskType::mesh_for ||
        stmt->task_type == OffloadedTaskType::struct_for) {
      auto uniquely_accessed_pointers =
          irpass::analysis::gather_uniquely_accessed_pointers(stmt);
      loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first);
      loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second);
    }
    current_offloaded = stmt;
    // We don't need to visit TLS/BLS prologues/epilogues.
    if (stmt->body) {
      if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
          stmt->task_type == OffloadedTaskType::mesh_for ||
          stmt->task_type == OffloadedStmt::TaskType::struct_for)
        visit_loop(stmt->body.get());
      else
        stmt->body->accept(this);
    }
    current_offloaded = nullptr;
  }

  bool is_offload_unique(Stmt *stmt) {
    if (current_offloaded->task_type == OffloadedTaskType::serial) {
      return true;
    }
    if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
      auto snode = global_ptr->snode;
      if (loop_unique_ptr_[snode] == nullptr ||
          loop_unique_ptr_[snode]->indices.empty()) {
        // not uniquely accessed
        return false;
      }
      if (current_offloaded->mem_access_opt.has_flag(
              snode, SNodeAccessFlag::block_local) ||
          current_offloaded->mem_access_opt.has_flag(
              snode, SNodeAccessFlag::mesh_local)) {
        // BLS does not support write access yet so we keep atomic_adds.
        return false;
      }
      return true;
    } else if (stmt->is<ExternalPtrStmt>()) {
      ExternalPtrStmt *dest_ptr = stmt->as<ExternalPtrStmt>();
      if (dest_ptr->indices.empty()) {
        return false;
      }
      ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
      int arg_id = arg_load_stmt->arg_id;
      if (loop_unique_arr_ptr_[arg_id] == nullptr) {
        // Not loop unique
        return false;
      }
      return true;
      // TODO: Is BLS / Mem Access Opt a thing for any_arr?
    }
    return false;
  }

  void visit_loop(Block *body) override {
    cached_maps.emplace();

    loop_blocks.push(body);

    body->accept(this);

    for (auto &[dest, status_vec] : cached_maps.top()) {
      auto &[status, vec] = status_vec;
      auto alloca_unique =
          std::make_unique<AllocaStmt>(dest->ret_type.ptr_removed());
      auto alloca_stmt = alloca_unique.get();
      modifier.insert_before(body->parent_stmt, std::move(alloca_unique));
      if (int(status) & int(CacheStatus::Read)) {
        set_init_value(alloca_stmt, dest);
      }
      if (int(status) & int(CacheStatus::Write)) {
        add_writeback(alloca_stmt, dest);
      }
      for (auto *stmt : vec) {
        if (stmt->is<GlobalLoadStmt>()) {
          auto local_load = std::make_unique<LocalLoadStmt>(alloca_stmt);
          stmt->replace_usages_with(local_load.get());
          modifier.insert_before(stmt, std::move(local_load));
          modifier.erase(stmt);
        } else if (auto *global_store = stmt->cast<GlobalStoreStmt>()) {
          auto local_store =
              std::make_unique<LocalStoreStmt>(alloca_stmt, global_store->val);
          stmt->replace_usages_with(local_store.get());
          modifier.insert_before(stmt, std::move(local_store));
          modifier.erase(stmt);
        } else {
          TI_UNREACHABLE
        }
      }
    }
    loop_blocks.pop();
    cached_maps.pop();
  }

  void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var) {
    auto final_value = std::make_unique<LocalLoadStmt>(alloca_stmt);
    auto global_store =
        std::make_unique<GlobalStoreStmt>(global_var, final_value.get());
    modifier.insert_after(current_loop_stmt(), std::move(global_store));
    modifier.insert_after(current_loop_stmt(), std::move(final_value));
  }

  void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var) {
    auto new_global_load = std::make_unique<GlobalLoadStmt>(global_var);
    auto local_store =
        std::make_unique<LocalStoreStmt>(alloca_stmt, new_global_load.get());
    modifier.insert_before(current_loop_stmt(), std::move(new_global_load));
    modifier.insert_before(current_loop_stmt(), std::move(local_store));
  }

  void cache_global_to_local(Stmt *stmt, Stmt *dest, CacheStatus status) {
    if (auto &[cached_status, vec] = cached_maps.top()[dest];
        cached_status != CacheStatus::None) {
      // The global variable has already been cached.
      if (cached_status == CacheStatus::Read && status == CacheStatus::Write) {
        cached_status = CacheStatus::ReadWrite;
      }
      vec.push_back(stmt);
      return;
    }
    cached_maps.top()[dest] = {status, {stmt}};
  }

  void visit(GlobalLoadStmt *stmt) override {
    if (is_offload_unique(stmt->src) &&
        is_operand_loop_invariant(stmt->src, stmt->parent)) {
      cache_global_to_local(stmt, stmt->src, CacheStatus::Read);
    }
  }

  void visit(GlobalStoreStmt *stmt) override {
    if (is_offload_unique(stmt->dest) &&
        is_operand_loop_invariant(stmt->dest, stmt->parent)) {
      cache_global_to_local(stmt, stmt->dest, CacheStatus::Write);
    }
  }

  static bool run(IRNode *node, const CompileConfig &config) {
    bool modified = false;

    while (true) {
      CacheLoopInvariantGlobalVars eliminator(config);
      node->accept(&eliminator);
      if (eliminator.modifier.modify_ir())
        modified = true;
      else
        break;
    };

    return modified;
  }
};

namespace irpass {
bool cache_loop_invariant_global_vars(IRNode *root,
                                      const CompileConfig &config) {
  TI_AUTO_PROF;
  return CacheLoopInvariantGlobalVars::run(root, config);
}
}  // namespace irpass

TLANG_NAMESPACE_END
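
Conceptually, the pass performs on CHI IR what a hand optimization of the earlier kernel would look like: `set_init_value` emits one global load into the alloca before the loop, the cached statements become local loads/stores, and `add_writeback` emits one global store after the loop. An illustrative hand-written equivalent (not from the PR; the real transform rewrites IR, not Python source):

```python
# Illustrative hand-written equivalent (not from the PR); fields and shapes
# are assumptions reused from the earlier sketch.
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=(16, 128))
y = ti.field(ti.f32, shape=16)

@ti.kernel
def accumulate_hand_cached():
    for i in range(16):
        acc = y[i]              # set_init_value: one global load before the loop
        for j in range(128):
            acc += x[i, j]      # cached accesses become local loads/stores
        y[i] = acc              # add_writeback: one global store after the loop

accumulate_hand_cached()
```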
7 changes: 7 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
@@ -186,6 +186,10 @@ void offload_to_executable(IRNode *ir,
  irpass::demote_atomics(ir, config);
  print("Atomics demoted I");
  irpass::analysis::verify(ir);
  if (config.cache_loop_invariant_global_vars) {
    irpass::cache_loop_invariant_global_vars(ir, config);
    print("Cache loop-invariant global vars");
  }

  if (config.demote_dense_struct_fors) {
    irpass::demote_dense_struct_fors(ir, config.packed);
@@ -246,6 +250,9 @@ void offload_to_executable(IRNode *ir,
  irpass::analysis::verify(ir);

  if (lower_global_access) {
    irpass::full_simplify(ir, config,
                          {false, /*autodiff_enabled*/ false, kernel->program});
    print("Simplified before lower access");
    irpass::lower_access(ir, config, {kernel->no_activate, true});
    print("Access lowered");
    irpass::analysis::verify(ir);
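
The pass is inserted right after `demote_atomics`, so accumulations into loop-unique destinations that have just been demoted from atomic ops to plain load/add/store become cacheable; the `ExternalPtrStmt` branch in `is_offload_unique` means the same applies to ndarray arguments. A hedged sketch of such a case (the `row_sums` kernel, shapes, and dtypes are assumptions; ndarray kernel arguments as in Taichi 1.x):

```python
# Illustrative only: kernel name, shapes, and dtypes are assumptions.
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def row_sums(arr: ti.types.ndarray(), out: ti.types.ndarray()):
    for i in range(out.shape[0]):
        for j in range(arr.shape[1]):
            out[i] += arr[i, j]  # out[i] is loop-unique and invariant w.r.t. j

a = np.random.rand(16, 128).astype(np.float32)
s = np.zeros(16, dtype=np.float32)
row_sums(a, s)
```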