Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 19, 2022
1 parent 64450a6 commit 7fc61b8
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions taichi/transforms/cache_loop_invariant_global_vars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
public:
using LoopInvariantDetector::visit;

enum class CacheStatus { None = 0, Read = 1, Write = 2, ReadWrite = 3, HasAtomic = 4 };

typedef std::unordered_map<Stmt *, std::pair<CacheStatus, std::vector<Stmt *>>> CacheMap;
enum class CacheStatus {
None = 0,
Read = 1,
Write = 2,
ReadWrite = 3,
HasAtomic = 4
};

typedef std::unordered_map<Stmt *,
std::pair<CacheStatus, std::vector<Stmt *>>>
CacheMap;
std::stack<CacheMap> cached_maps;

DelayedIRModifier modifier;
Expand Down Expand Up @@ -108,8 +116,7 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
stmt->replace_usages_with(local_load.get());
modifier.insert_before(stmt, std::move(local_load));
modifier.erase(stmt);
}
else if (auto *global_store = stmt->cast<GlobalStoreStmt>()) {
} else if (auto *global_store = stmt->cast<GlobalStoreStmt>()) {
auto local_store =
std::make_unique<LocalStoreStmt>(alloca_stmt, global_store->val);
stmt->replace_usages_with(local_store.get());
Expand Down Expand Up @@ -140,12 +147,11 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
modifier.insert_before(current_loop_stmt(), std::move(local_store));
}

void cache_global_to_local(Stmt *stmt, Stmt *dest,
CacheStatus status) {
if (auto &[cached_status, vec] = cached_maps.top()[dest]; cached_status != CacheStatus::None) {
void cache_global_to_local(Stmt *stmt, Stmt *dest, CacheStatus status) {
if (auto &[cached_status, vec] = cached_maps.top()[dest];
cached_status != CacheStatus::None) {
// The global variable has already been cached.
if (cached_status == CacheStatus::Read &&
status == CacheStatus::Write) {
if (cached_status == CacheStatus::Read && status == CacheStatus::Write) {
cached_status = CacheStatus::ReadWrite;
}
vec.push_back(stmt);
Expand All @@ -155,17 +161,16 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
}

void visit(GlobalLoadStmt *stmt) override {
if (is_offload_unique(stmt->src) && is_operand_loop_invariant_impl(stmt->src, stmt->parent)) {
cache_global_to_local(stmt,
stmt->src, CacheStatus::Read);
if (is_offload_unique(stmt->src) &&
is_operand_loop_invariant_impl(stmt->src, stmt->parent)) {
cache_global_to_local(stmt, stmt->src, CacheStatus::Read);
}
}

void visit(GlobalStoreStmt *stmt) override {
if (is_offload_unique(stmt->dest) && is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) {
cache_global_to_local(stmt,
stmt->dest, CacheStatus::Write);

if (is_offload_unique(stmt->dest) &&
is_operand_loop_invariant_impl(stmt->dest, stmt->parent)) {
cache_global_to_local(stmt, stmt->dest, CacheStatus::Write);
}
}

Expand Down

0 comments on commit 7fc61b8

Please sign in to comment.