Skip to content

Commit

Permalink
[async] Optimization stage 1 (#994)
Browse files Browse the repository at this point in the history
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
yuanming-hu and taichi-gardener authored May 15, 2020
1 parent 65c151a commit 333c86c
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 39 deletions.
75 changes: 75 additions & 0 deletions misc/test_fuse_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import taichi as ti
import time

ti.init(async=True)

x = ti.var(ti.i32)
y = ti.var(ti.i32)
z = ti.var(ti.i32)

ti.root.dense(ti.i, 1024**3).place(x)
ti.root.dense(ti.i, 1024**3).place(y)
ti.root.dense(ti.i, 1024**3).place(z)


@ti.kernel
def x_to_y():
for i in x:
y[i] = x[i] + 1


@ti.kernel
def y_to_z():
for i in x:
z[i] = y[i] + 4


@ti.kernel
def inc():
for i in x:
x[i] = x[i] + 1


n = 100

for i in range(n):
x[i] = i * 10

repeat = 10

for i in range(repeat):
t = time.time()
x_to_y()
ti.sync()
print('x_to_y', time.time() - t)

for i in range(repeat):
t = time.time()
y_to_z()
ti.sync()
print('y_to_z', time.time() - t)

for i in range(repeat):
t = time.time()
x_to_y()
y_to_z()
ti.sync()
print('fused x->y->z', time.time() - t)

for i in range(repeat):
t = time.time()
inc()
ti.sync()
print('single inc', time.time() - t)

for i in range(repeat):
t = time.time()
for j in range(10):
inc()
ti.sync()
print('fused 10 inc', time.time() - t)

for i in range(n):
assert x[i] == i * 10
assert y[i] == x[i] + 1
assert z[i] == x[i] + 5
35 changes: 35 additions & 0 deletions misc/test_fuse_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import taichi as ti

ti.init()

x = ti.var(ti.i32)
y = ti.var(ti.i32)
z = ti.var(ti.i32)

ti.root.dynamic(ti.i, 1048576, chunk_size=2048).place(x, y, z)


@ti.kernel
def x_to_y():
for i in x:
y[i] = x[i] + 1


@ti.kernel
def y_to_z():
for i in x:
z[i] = y[i] + 1


n = 10000

for i in range(n):
x[i] = i * 10

x_to_y()
y_to_z()

for i in range(n):
x[i] = i * 10
assert y[i] == x[i] + 1
assert z[i] == x[i] + 2
18 changes: 10 additions & 8 deletions taichi/analysis/clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"

#include <unordered_map>

TLANG_NAMESPACE_BEGIN
Expand All @@ -12,10 +14,7 @@ class IRCloner : public IRVisitor {
std::unordered_map<Stmt *, Stmt *> operand_map;

public:
enum Phase {
register_operand_map,
replace_operand
} phase;
enum Phase { register_operand_map, replace_operand } phase;

explicit IRCloner(IRNode *other_node)
: other_node(other_node), phase(register_operand_map) {
Expand Down Expand Up @@ -111,22 +110,25 @@ class IRCloner : public IRVisitor {
}
}

static std::unique_ptr<IRNode> run(IRNode *root) {
static std::unique_ptr<IRNode> run(IRNode *root, Kernel *kernel) {
if (kernel == nullptr) {
kernel = &get_current_program().get_current_kernel();
}
std::unique_ptr<IRNode> new_root = root->clone();
IRCloner cloner(new_root.get());
cloner.phase = IRCloner::register_operand_map;
root->accept(&cloner);
cloner.phase = IRCloner::replace_operand;
root->accept(&cloner);
irpass::typecheck(new_root.get());
irpass::typecheck(new_root.get(), kernel);
irpass::fix_block_parents(new_root.get());
return new_root;
}
};

namespace irpass::analysis {
std::unique_ptr<IRNode> clone(IRNode *root) {
return IRCloner::run(root);
std::unique_ptr<IRNode> clone(IRNode *root, Kernel *kernel) {
return IRCloner::run(root, kernel);
}
} // namespace irpass::analysis

Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DiffRange {
namespace irpass::analysis {

void check_fields_registered(IRNode *root);
std::unique_ptr<IRNode> clone(IRNode *root);
std::unique_ptr<IRNode> clone(IRNode *root, Kernel *kernel = nullptr);
int count_statements(IRNode *root);
std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
Expand Down
Loading

0 comments on commit 333c86c

Please sign in to comment.