Skip to content

Commit

Permalink
[Bug] Fix Loop-Invariant-Cache for dynamic indexed pointers (#8577)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

copilot:summary

### Walkthrough

copilot:walkthrough
  • Loading branch information
jim19930609 authored Aug 8, 2024
1 parent 37a0563 commit 58134e4
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ numpy
psutil
autograd
requests
matplotlib
matplotlib<=3.7.3
cffi
scipy
setproctitle
Expand Down
124 changes: 124 additions & 0 deletions taichi/analysis/gather_dynamically_indexed_pointers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include "taichi/analysis/gather_uniquely_accessed_pointers.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"
#include <algorithm>

namespace taichi::lang {

bool is_leaf_nodes_on_same_branch(SNode *snode0, SNode *snode1) {
// Verify: place snode
if (!snode0->is_place() || !snode1->is_place()) {
return false;
}

// Check parent snode
if (snode0->parent != snode1->parent) {
return false;
}

return true;
}

class DynamicIndexingAnalyzer : public BasicStmtVisitor {
void record_dynamic_indexed_ptr(ExternalPtrStmt *extern_ptr) {
dynamically_indexed_ptrs_.insert(extern_ptr);
// Find aliased ExternPtrStmt
for (auto *other_extern_ptr : extern_ptrs_) {
if (other_extern_ptr != extern_ptr &&
other_extern_ptr->base_ptr == extern_ptr->base_ptr) {
// Aliased ExternalPtrStmt, with same base_ptr and outter index
dynamically_indexed_ptrs_.insert(other_extern_ptr);
}
}
}

void record_dynamic_indexed_ptr(GlobalPtrStmt *global_ptr) {
dynamically_indexed_ptrs_.insert(global_ptr);
// Find aliased GlobalPtrStmt
for (auto *other_global_ptr : global_ptrs_) {
if (other_global_ptr != global_ptr &&
is_leaf_nodes_on_same_branch(other_global_ptr->snode,
global_ptr->snode)) {
dynamically_indexed_ptrs_.insert(other_global_ptr);
}
}
}

public:
explicit DynamicIndexingAnalyzer(IRNode *node) {
}

void visit(GlobalPtrStmt *stmt) override {
for (auto *index_stmt : stmt->indices) {
if (!index_stmt->is<ConstStmt>() && !index_stmt->is<LoopIndexStmt>()) {
record_dynamic_indexed_ptr(stmt);
}
}

global_ptrs_.insert(stmt);
}

void visit(ExternalPtrStmt *stmt) override {
for (auto *index_stmt : stmt->indices) {
if (!index_stmt->is<ConstStmt>() && !index_stmt->is<LoopIndexStmt>()) {
record_dynamic_indexed_ptr(stmt);
}
}

extern_ptrs_.insert(stmt);
}

void visit(MatrixPtrStmt *stmt) override {
GlobalPtrStmt *global_ptr = nullptr;
ExternalPtrStmt *extern_ptr = nullptr;

if (stmt->origin->is<GlobalPtrStmt>()) {
global_ptr = stmt->origin->as<GlobalPtrStmt>();
} else if (stmt->origin->is<ExternalPtrStmt>()) {
extern_ptr = stmt->origin->as<ExternalPtrStmt>();
} else {
return;
}

// Is dynamic index
if (stmt->offset->is<ConstStmt>()) {
return;
}

if (global_ptr) {
record_dynamic_indexed_ptr(global_ptr);
}

if (extern_ptr) {
record_dynamic_indexed_ptr(extern_ptr);
}
}

std::unordered_set<Stmt *> get_dynamically_indexed_ptrs() {
return dynamically_indexed_ptrs_;
}

private:
using BasicStmtVisitor::visit;
std::unordered_set<Stmt *> dynamically_indexed_ptrs_;
std::unordered_set<GlobalPtrStmt *> global_ptrs_;
std::unordered_set<ExternalPtrStmt *> extern_ptrs_;
};

namespace irpass::analysis {

std::unordered_set<Stmt *> gather_dynamically_indexed_pointers(IRNode *root) {
DynamicIndexingAnalyzer pass(root);

// This pass is intended to run twice
root->accept(&pass);
root->accept(&pass);

auto dynamically_indexed_ptrs = pass.get_dynamically_indexed_ptrs();
return dynamically_indexed_ptrs;
}

} // namespace irpass::analysis
} // namespace taichi::lang
1 change: 1 addition & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ std::tuple<std::unordered_map<const SNode *, GlobalPtrStmt *>,
hashing::Hasher<std::vector<int>>>,
std::unordered_set<MatrixPtrStmt *>>
gather_uniquely_accessed_pointers(IRNode *root);
std::unordered_set<Stmt *> gather_dynamically_indexed_pointers(IRNode *root);

std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
IRNode *root);
Expand Down
30 changes: 30 additions & 0 deletions taichi/transforms/cache_loop_invariant_global_vars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
loop_unique_arr_ptr_;
std::unordered_set<MatrixPtrStmt *> loop_unique_matrix_ptr_;

std::unordered_set<Stmt *> dynamic_indexed_ptrs_;

OffloadedStmt *current_offloaded;

explicit CacheLoopInvariantGlobalVars(const CompileConfig &config)
Expand All @@ -44,6 +46,9 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
std::move(std::get<2>(uniquely_accessed_pointers));
}
current_offloaded = stmt;
dynamic_indexed_ptrs_ =
irpass::analysis::gather_dynamically_indexed_pointers(stmt);

// We don't need to visit TLS/BLS prologues/epilogues.
if (stmt->body) {
if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
Expand All @@ -56,6 +61,28 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
current_offloaded = nullptr;
}

bool is_dynamically_indexed(Stmt *stmt) {
// Handle GlobalPtrStmt
Stmt *ptr_stmt = nullptr;
if (stmt->is<GlobalPtrStmt>()) {
ptr_stmt = stmt->as<GlobalPtrStmt>();
} else if (stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<GlobalPtrStmt>()) {
ptr_stmt = stmt->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else if (stmt->is<ExternalPtrStmt>()) {
ptr_stmt = stmt->as<ExternalPtrStmt>();
} else if (stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<ExternalPtrStmt>()) {
ptr_stmt = stmt->as<MatrixPtrStmt>()->origin->as<ExternalPtrStmt>();
}

if (ptr_stmt && dynamic_indexed_ptrs_.count(ptr_stmt)) {
return true;
}

return false;
}

bool is_offload_unique(Stmt *stmt) {
if (current_offloaded->task_type == OffloadedTaskType::serial) {
return true;
Expand Down Expand Up @@ -174,6 +201,9 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {

std::optional<int> find_cache_depth_if_cacheable(Stmt *operand,
Block *current_scope) {
if (is_dynamically_indexed(operand)) {
return std::nullopt;
}
if (!is_offload_unique(operand)) {
return std::nullopt;
}
Expand Down
27 changes: 27 additions & 0 deletions tests/python/test_cache_loop_invariant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from taichi.lang import impl

import taichi as ti
from tests import test_utils


@test_utils.test(arch=[ti.cuda, ti.cpu])
def test_local_matrix_non_constant_index_real_matrix():
N = 1
x = ti.Vector.field(3, float, shape=1)

@ti.kernel
def test_invariant_cache():
for i in range(1):
x[i][1] = x[i][1] + 1.0
for j in range(1):
x[i][1] = x[i][1] - 5.0
for z in range(1):
idx = 0
if z == 0:
idx = 1
x_print = x[i][idx]

assert x_print == x[i][1]

test_invariant_cache()

0 comments on commit 58134e4

Please sign in to comment.