[async] Clone offloaded tasks lazily by maintaining a cached template task (#1619)
k-ye authored Aug 1, 2020
1 parent 0dc7ed7 commit 81437d6
Showing 4 changed files with 109 additions and 38 deletions.
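
Before the diff itself, here is a minimal, self-contained sketch of the pattern this commit introduces: a launch record starts out pointing at a shared, read-only template task and only pays for a deep clone the first time a pass needs to mutate it. All names (Task, CachedTemplate, LaunchRecord) and the string bodies are illustrative stand-ins, not the actual Taichi types; the real implementation is in the changes to async_engine.cpp and async_engine.h below.

#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>

// Stand-in for an offloaded task's IR; clone() mimics a deep IR clone.
struct Task {
  std::string body;
  std::unique_ptr<Task> clone() const { return std::make_unique<Task>(*this); }
};

// Computed once per offloaded task of a kernel; owns the read-only template.
class CachedTemplate {
 public:
  CachedTemplate(std::unique_ptr<Task> tmpl, uint64_t hash)
      : tmpl_(std::move(tmpl)), hash_(hash) {}
  Task *get_template() const { return tmpl_.get(); }
  uint64_t get_hash() const { return hash_; }

 private:
  std::unique_ptr<Task> tmpl_;  // ownership never leaves the cache
  uint64_t hash_;
};

// One record per kernel launch; cheap to build because nothing is cloned up front.
class LaunchRecord {
 public:
  explicit LaunchRecord(Task *tmpl) : stmt_(tmpl) {}
  Task *stmt() { return stmt_; }

  // Clone lazily, and only once; later calls reuse the private copy.
  Task *clone_stmt_on_write() {
    if (!cloned_) {
      cloned_ = stmt_->clone();
      stmt_ = cloned_.get();
    }
    return stmt_;
  }

 private:
  Task *stmt_;                    // the template at first, the private copy after CoW
  std::unique_ptr<Task> cloned_;  // owns the copy, if one was ever made
};

int main() {
  CachedTemplate cached(std::make_unique<Task>(Task{"x += 1"}), /*hash=*/42);

  LaunchRecord rec_a(cached.get_template());  // read-only launch: no clone
  LaunchRecord rec_b(cached.get_template());
  rec_b.clone_stmt_on_write()->body += "; y += 2";  // e.g. fusion mutates rec_b

  std::printf("template: %s\n", cached.get_template()->body.c_str());  // untouched
  std::printf("rec_a:    %s\n", rec_a.stmt()->body.c_str());
  std::printf("rec_b:    %s\n", rec_b.stmt()->body.c_str());
  return 0;
}

Hiding the unique_ptr behind get_template() mirrors the commit's OffloadedCachedData: ownership stays in one place, and mutation always goes through an explicit clone-on-write call.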
2 changes: 1 addition & 1 deletion taichi/analysis/clone.cpp
@@ -123,9 +123,9 @@ class IRCloner : public IRVisitor {
root->accept(&cloner);

using namespace irpass;
typecheck(new_root.get());
fix_block_parents(new_root.get());
fix_root_block_kernel(new_root.get(), kernel);
typecheck(new_root.get());
return new_root;
}
};
77 changes: 48 additions & 29 deletions taichi/program/async_engine.cpp
@@ -27,39 +27,51 @@ uint64 hash(IRNode *stmt) {
return ret;
}

std::unique_ptr<IRNode> clone_offloaded_task(OffloadedStmt *from,
Kernel *kernel,
Block *dummy_root) {
std::unique_ptr<OffloadedStmt> clone_offloaded_task(OffloadedStmt *from,
Kernel *kernel,
Block *dummy_root) {
auto new_ir = irpass::analysis::clone(from, kernel);
// This is not the ideal fix, because |new_ir|'s child blocks are NOT
// linked to |dummy_root|. However, if I manually do the linking, I get an
// error during LLVM codegen.
new_ir->as<OffloadedStmt>()->parent = dummy_root;
return new_ir;
return std::unique_ptr<OffloadedStmt>((OffloadedStmt *)(new_ir.release()));
}

} // namespace

KernelLaunchRecord::KernelLaunchRecord(Context context,
Kernel *kernel,
std::unique_ptr<IRNode> &&stmt_,
uint64 h)
OffloadedStmt *stmt,
uint64 h,
Block *dummy_root)
: context(context),
kernel(kernel),
stmt(dynamic_cast<OffloadedStmt *>(stmt_.get())),
h(h),
stmt_holder_(std::move(stmt_)) {
TI_ASSERT(stmt != nullptr);
TI_ASSERT(stmt->get_kernel() != nullptr);
stmt_(stmt),
dummy_root_(dummy_root),
cloned_stmt_holder_(nullptr) {
TI_ASSERT(stmt_ != nullptr);
TI_ASSERT(stmt_->get_kernel() != nullptr);
}

OffloadedStmt *KernelLaunchRecord::clone_stmt_on_write() {
if (cloned_stmt_holder_ == nullptr) {
cloned_stmt_holder_ = clone_offloaded_task(stmt_, kernel, dummy_root_);
stmt_ = cloned_stmt_holder_.get();
}
return stmt_;
}

void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
auto h = ker.h;
auto stmt = ker.stmt;
auto *stmt = ker.stmt();
auto kernel = ker.kernel;
if (compiled_func.find(h) == compiled_func.end() &&
to_be_compiled.find(h) == to_be_compiled.end()) {
to_be_compiled.insert(h);
// Later the IR passes will change |stmt|, so we must clone it.
stmt = ker.clone_stmt_on_write();
compilation_workers.enqueue([&, stmt, kernel, h, this]() {
{
// Final lowering
@@ -82,7 +94,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
}

auto context = ker.context;
launch_worker.enqueue([&, h, stmt, context, this] {
launch_worker.enqueue([&, h, task_type = stmt->task_type, context, this] {
FunctionType func;
while (true) {
std::unique_lock<std::mutex> lock(mut);
@@ -95,7 +107,6 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
break;
}
stat.add("launched_kernels", 1.0);
auto task_type = stmt->task_type;
if (task_type == OffloadedStmt::TaskType::listgen) {
stat.add("launched_kernels_list_op", 1.0);
stat.add("launched_kernels_list_gen", 1.0);
@@ -140,18 +151,23 @@ void AsyncEngine::launch(Kernel *kernel) {
kmeta.dummy_root->kernel = kernel;
}
for (std::size_t i = 0; i < offloads.size(); i++) {
auto offload = offloads[i]->as<OffloadedStmt>();
auto cloned = clone_offloaded_task(offload, kernel, kmeta.dummy_root.get());
auto *offload = offloads[i]->as<OffloadedStmt>();
uint64 h;
OffloadedStmt *offl_template = nullptr;
if (kmeta_inited) {
h = kmeta.offloaded_hashes[i];
auto &oc = kmeta.offloaded_cached[i];
h = oc.get_hash();
offl_template = oc.get_template();
} else {
h = hash(cloned.get());
TI_ASSERT(kmeta.offloaded_hashes.size() == i);
kmeta.offloaded_hashes.push_back(h);
auto cloned_offs =
clone_offloaded_task(offload, kernel, kmeta.dummy_root.get());
offl_template = cloned_offs.get();
h = hash(offl_template);
TI_ASSERT(kmeta.offloaded_cached.size() == i);
kmeta.offloaded_cached.emplace_back(std::move(cloned_offs), h);
}
KernelLaunchRecord rec(kernel->program.get_context(), kernel,
std::move(cloned), h);
KernelLaunchRecord rec(kernel->program.get_context(), kernel, offl_template,
h, kmeta.dummy_root.get());
enqueue(std::move(rec));
}
}
@@ -161,7 +177,7 @@ void AsyncEngine::enqueue(KernelLaunchRecord &&t) {

auto &meta = offloaded_metas_[t.h];
// TODO: this is an abuse since it gathers nothing...
auto root_stmt = t.stmt;
auto *root_stmt = t.stmt();
gather_statements(root_stmt, [&](Stmt *stmt) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
@@ -225,12 +241,12 @@ bool AsyncEngine::optimize_listgen() {
// Try to eliminate unused listgens
auto &t = task_queue[i];
auto meta = offloaded_metas_[t.h];
auto offload = t.stmt;
const auto *offload = t.stmt();
bool keep = true;
if (offload->task_type == OffloadedStmt::TaskType::listgen) {
// keep
} else if (offload->task_type == OffloadedStmt::TaskType::clear_list) {
TI_ASSERT(task_queue[i + 1].stmt->task_type ==
TI_ASSERT(task_queue[i + 1].stmt()->task_type ==
OffloadedStmt::TaskType::listgen);
auto snode = offload->snode;
if (list_dirty.find(snode) != list_dirty.end() && !list_dirty[snode]) {
@@ -266,16 +282,16 @@ bool AsyncEngine::fuse() {
if (false) {
// (experimental) print tasks
for (int i = 0; i < (int)task_queue.size(); i++) {
fmt::print("{}: {}\n", i, task_queue[i].stmt->task_name());
irpass::print(task_queue[i].stmt);
fmt::print("{}: {}\n", i, task_queue[i].stmt()->task_name());
irpass::print(task_queue[i].stmt());
}
}

for (int i = 0; i < (int)task_queue.size() - 1; i++) {
auto &rec_a = task_queue[i];
auto &rec_b = task_queue[i + 1];
auto task_a = rec_a.stmt;
auto task_b = rec_b.stmt;
auto *task_a = rec_a.stmt();
auto *task_b = rec_b.stmt();
bool is_same_struct_for = task_a->task_type == OffloadedStmt::struct_for &&
task_b->task_type == OffloadedStmt::struct_for &&
task_a->snode == task_b->snode &&
@@ -310,6 +326,9 @@ bool AsyncEngine::fuse() {
kernel_args_match = (check(rec_a.kernel) && check(rec_b.kernel));
}
if (kernel_args_match && (is_same_range_for || is_same_struct_for)) {
// We are about to change both |task_a| and |task_b|. Clone them first.
task_a = rec_a.clone_stmt_on_write();
task_b = rec_b.clone_stmt_on_write();
// TODO: in certain cases this optimization can be wrong!
// Fuse task b into task_a
for (int j = 0; j < (int)task_b->body->size(); j++) {
@@ -334,7 +353,7 @@

// Eliminate empty tasks
for (int i = 0; i < (int)task_queue.size(); i++) {
auto task = task_queue[i].stmt;
auto *task = task_queue[i].stmt();
bool keep = true;
if (task->task_type == OffloadedStmt::struct_for ||
task->task_type == OffloadedStmt::range_for ||
63 changes: 56 additions & 7 deletions taichi/program/async_engine.h
@@ -109,16 +109,33 @@ class KernelLaunchRecord {
public:
Context context;
Kernel *kernel; // TODO: remove this
OffloadedStmt *stmt;
uint64 h; // hash of |stmt|
uint64 h; // hash of |stmt|

KernelLaunchRecord(Context context,
Kernel *kernel,
std::unique_ptr<IRNode> &&stmt,
uint64 h);
OffloadedStmt *stmt,
uint64 h,
Block *dummy_root);

inline OffloadedStmt *stmt() {
return stmt_;
}

// When we need to make changes to |stmt|, call this method so that |stmt| is
// cloned from the template, leaving the template itself untouched.
//
// Cloning will only happen on the first call.
OffloadedStmt *clone_stmt_on_write();

private:
std::unique_ptr<IRNode> stmt_holder_;
// This starts out as the template in OffloadedCachedData. Once
// clone_stmt_on_write() is invoked, it points to the cloned task owned by
// |cloned_stmt_holder_|.
OffloadedStmt *stmt_;

// These are for cloning |stmt_|.
Block *dummy_root_; // Not owned
std::unique_ptr<OffloadedStmt> cloned_stmt_holder_;
};

// In charge of (parallel) compilation to binary and (serial) kernel launching
@@ -181,7 +198,40 @@ class AsyncEngine {
private:
struct KernelMeta {
std::unique_ptr<Block> dummy_root;
std::vector<uint64> offloaded_hashes;

// OffloadedCachedData holds some data that needs to be computed once for
// each offloaded task of a kernel. In particular, it holds a cloned offloaded
// task, but uses it as a READ-ONLY template. That is, code that later finds
// it necessary to mutate this task (e.g. kernel fusion) should do another
// clone, so that the template in this class stays untouched.
//
// This design allows us to do task cloning lazily. It turned out that cloning
// on every kernel launch is too expensive.
struct OffloadedCachedData {
public:
explicit OffloadedCachedData(std::unique_ptr<OffloadedStmt> &&tmpl,
uint64 hash)
: tmpl_(std::move(tmpl)), hash_(hash) {
}

// Get the read-only offloaded task template. Ideally this should be a
// const pointer, but the IR passes won't work...
inline OffloadedStmt *get_template() {
return tmpl_.get();
}

inline uint64 get_hash() const {
return hash_;
}

private:
// Hide the unique pointer so that the ownership cannot be accidentally
// transferred.
std::unique_ptr<OffloadedStmt> tmpl_;
uint64 hash_;
};

std::vector<OffloadedCachedData> offloaded_cached;

inline bool initialized() const {
return dummy_root != nullptr;
@@ -197,7 +247,6 @@ class AsyncEngine {
// This map provides a dummy Block root for these OffloadedStmt, so that
// get_kernel() could still work correctly.
std::unordered_map<const Kernel *, KernelMeta> kernel_metas_;

std::unordered_map<std::uint64_t, TaskMeta> offloaded_metas_;
};

5 changes: 4 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
@@ -106,6 +106,9 @@ void compile_to_offloads(IRNode *ir,

print("Access flagged II");
irpass::analysis::verify(ir);

irpass::full_simplify(ir, /*after_lower_access=*/false);
print("Simplified III");
}

void compile_to_executable(IRNode *ir,
@@ -166,7 +169,7 @@ void compile_to_executable(IRNode *ir,
irpass::analysis::verify(ir);

irpass::full_simplify(ir, lower_global_access);
print("Simplified III");
print("Simplified IV");

// Final field registration correctness & type checking
irpass::typecheck(ir);
