Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Perf] Improve dynamic SNode performance #1182

Merged
merged 2 commits into from
Jun 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/contributor_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ Existing tags:
- ``[Windows]``: Windows platform;
- ``[PyPI]``: PyPI package release;
- ``[Workflow]``: GitHub Actions/Workflows;
- ``[Perf]``: performance improvements;
- ``[Misc]``: something that doesn't belong to any category, such as version bump, reformatting;
- ``[Bug]``: bug fixes;
- **When introducing a new tag, please update the list here in the first PR with that tag, so that people can follow.**
Expand Down
1 change: 1 addition & 0 deletions python/taichi/make_changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def format(c):
'linux': 'Linux',
'mac': 'Mac OS X',
'windows': 'Windows',
'perf': 'Performance improvements',
'release': '',
}

Expand Down
6 changes: 5 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ void compile_to_offloads(IRNode *ir,
print("Store forwarded");
irpass::analysis::verify(ir);

irpass::flag_access(ir);
print("Access flagged I");
irpass::analysis::verify(ir);

if (lower_global_access) {
irpass::lower_access(ir, true);
print("Access lowered");
Expand All @@ -96,7 +100,7 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);

irpass::flag_access(ir);
print("Access flagged");
print("Access flagged II");
irpass::analysis::verify(ir);

irpass::constant_fold(ir);
Expand Down
31 changes: 24 additions & 7 deletions taichi/transforms/flag_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class WeakenAccess : public BasicStmtVisitor {
// Constructs the pass and immediately runs it over the IR rooted at `node`.
// Both loop-context pointers start out null, i.e. "not inside any
// struct-for / offloaded task yet".
WeakenAccess(IRNode *node)
    : current_offload(nullptr), current_struct_for(nullptr) {
  // Visitor configuration flags; presumably declared by the visitor base
  // class (not visible in this hunk) -- confirm their exact semantics there.
  invoke_default_visitor = false;
  allow_undefined_visitor = true;

  // Traversal starts only after all state above has been set up.
  node->accept(this);
}

Expand All @@ -86,6 +88,12 @@ class WeakenAccess : public BasicStmtVisitor {
}
}

// Visits the body of a struct-for while recording `stmt` as the enclosing
// struct-for, so that visit(GlobalPtrStmt *) can compare the SNodes being
// accessed against the SNode this loop iterates over.
void visit(StructForStmt *stmt) {
  // Remember whatever context was active on entry and put it back on exit.
  // Resetting to nullptr unconditionally would drop an enclosing struct-for
  // context (if one exists) for statements following this loop.
  StructForStmt *enclosing_struct_for = current_struct_for;
  current_struct_for = stmt;
  stmt->body->accept(this);
  current_struct_for = enclosing_struct_for;
}

void visit(OffloadedStmt *stmt) {
current_offload = stmt;
if (stmt->body)
Expand All @@ -95,20 +103,28 @@ class WeakenAccess : public BasicStmtVisitor {

void visit(GlobalPtrStmt *stmt) {
if (stmt->activate) {
if (current_offload &&
current_offload->task_type == OffloadedStmt::TaskType::struct_for) {
bool is_struct_for =
(current_offload &&
current_offload->task_type == OffloadedStmt::TaskType::struct_for) ||
current_struct_for;
if (is_struct_for) {
bool same_as_loop_snode = true;
for (auto snode : stmt->snodes.data) {
if (snode->type == SNodeType::place) {
snode = snode->parent;
}
if (snode != current_offload->snode) {
SNode *loop_snode = nullptr;
if (current_struct_for) {
loop_snode = current_struct_for->snode;
} else {
loop_snode = current_offload->snode;
}
if (snode != loop_snode) {
same_as_loop_snode = false;
}
if (stmt->indices.size() ==
current_offload->snode->num_active_indices)
for (int i = 0; i < current_offload->snode->num_active_indices;
i++) {
TI_ASSERT(loop_snode);
if (stmt->indices.size() == loop_snode->num_active_indices)
for (int i = 0; i < loop_snode->num_active_indices; i++) {
auto ind = stmt->indices[i];
// TODO: vectorized cases?
if (auto loop_var = ind->cast<LoopIndexStmt>()) {
Expand All @@ -128,6 +144,7 @@ class WeakenAccess : public BasicStmtVisitor {

private:
OffloadedStmt *current_offload;
StructForStmt *current_struct_for;
};

namespace irpass {
Expand Down
6 changes: 5 additions & 1 deletion taichi/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class LowerAccess : public IRVisitor {

void visit(GlobalLoadStmt *stmt) override {
if (stmt->ptr->is<GlobalPtrStmt>()) {
// No need to activate for all read accesses
auto lowered = lower_vector_ptr(stmt->ptr->as<GlobalPtrStmt>(), false);
stmt->ptr = lowered.back().get();
stmt->parent->insert_before(stmt, std::move(lowered));
Expand All @@ -186,7 +187,10 @@ class LowerAccess : public IRVisitor {

void visit(GlobalStoreStmt *stmt) override {
if (stmt->ptr->is<GlobalPtrStmt>()) {
auto lowered = lower_vector_ptr(stmt->ptr->as<GlobalPtrStmt>(), true);
auto ptr = stmt->ptr->as<GlobalPtrStmt>();
// If ptr already has activate = false, no need to activate all the
// generated micro-access ops. Otherwise, activate the nodes.
auto lowered = lower_vector_ptr(ptr, ptr->activate);
stmt->ptr = lowered.back().get();
stmt->parent->insert_before(stmt, std::move(lowered));
throw IRModified();
Expand Down