[async] Clone offloaded tasks lazily by caching the AST #1619

Merged 1 commit on Aug 1, 2020
2 changes: 1 addition & 1 deletion taichi/analysis/clone.cpp
@@ -123,9 +123,9 @@ class IRCloner : public IRVisitor {
root->accept(&cloner);

using namespace irpass;
typecheck(new_root.get());
fix_block_parents(new_root.get());
fix_root_block_kernel(new_root.get(), kernel);
typecheck(new_root.get());
return new_root;
}
};
77 changes: 48 additions & 29 deletions taichi/program/async_engine.cpp
@@ -27,39 +27,51 @@ uint64 hash(IRNode *stmt) {
return ret;
}

std::unique_ptr<IRNode> clone_offloaded_task(OffloadedStmt *from,
Kernel *kernel,
Block *dummy_root) {
std::unique_ptr<OffloadedStmt> clone_offloaded_task(OffloadedStmt *from,
Kernel *kernel,
Block *dummy_root) {
auto new_ir = irpass::analysis::clone(from, kernel);
// This is not the ideal fix, because |new_ir|'s child blocks are NOT
// linked to |dummy_root|. However, if I manually do the linking, I get an
// error during LLVM codegen.
new_ir->as<OffloadedStmt>()->parent = dummy_root;
return new_ir;
return std::unique_ptr<OffloadedStmt>((OffloadedStmt *)(new_ir.release()));
}

} // namespace
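
The return statement above downcasts with a C-style cast, which is safe here because clone() was invoked on an OffloadedStmt. A minimal sketch of a helper (the name is hypothetical, not part of this PR) that would keep the downcast and the ownership transfer explicit:

#include <memory>

// Downcast a unique_ptr<Base> to unique_ptr<Derived>; the caller asserts
// that the dynamic type really is Derived.
template <typename Derived, typename Base>
std::unique_ptr<Derived> static_unique_ptr_cast(std::unique_ptr<Base> &&p) {
  return std::unique_ptr<Derived>(static_cast<Derived *>(p.release()));
}

// clone_offloaded_task could then end with:
//   return static_unique_ptr_cast<OffloadedStmt>(std::move(new_ir));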

KernelLaunchRecord::KernelLaunchRecord(Context context,
Kernel *kernel,
std::unique_ptr<IRNode> &&stmt_,
uint64 h)
OffloadedStmt *stmt,
uint64 h,
Block *dummy_root)
: context(context),
kernel(kernel),
stmt(dynamic_cast<OffloadedStmt *>(stmt_.get())),
h(h),
stmt_holder_(std::move(stmt_)) {
TI_ASSERT(stmt != nullptr);
TI_ASSERT(stmt->get_kernel() != nullptr);
stmt_(stmt),
dummy_root_(dummy_root),
cloned_stmt_holder_(nullptr) {
TI_ASSERT(stmt_ != nullptr);
TI_ASSERT(stmt_->get_kernel() != nullptr);
}

OffloadedStmt *KernelLaunchRecord::clone_stmt_on_write() {
if (cloned_stmt_holder_ == nullptr) {
cloned_stmt_holder_ = clone_offloaded_task(stmt_, kernel, dummy_root_);
stmt_ = cloned_stmt_holder_.get();
}
return stmt_;
}
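
clone_stmt_on_write() is a copy-on-write hook: the record keeps borrowing the shared read-only template until a caller actually needs to mutate the task, and the deep clone happens at most once. A self-contained sketch of the same pattern, with Node as a hypothetical stand-in for OffloadedStmt:

#include <memory>

struct Node {
  int payload = 0;
  std::unique_ptr<Node> clone() const { return std::make_unique<Node>(*this); }
};

class CowRecord {
 public:
  explicit CowRecord(Node *tmpl) : view_(tmpl) {}

  Node *stmt() { return view_; }  // read-only access, never copies

  Node *clone_on_write() {  // call before any mutation
    if (owned_ == nullptr) {
      owned_ = view_->clone();  // deep copy, done at most once
      view_ = owned_.get();
    }
    return view_;
  }

 private:
  Node *view_;                   // borrowed until the first write
  std::unique_ptr<Node> owned_;  // owns the clone, if any
};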

void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
auto h = ker.h;
auto stmt = ker.stmt;
auto *stmt = ker.stmt();
auto kernel = ker.kernel;
if (compiled_func.find(h) == compiled_func.end() &&
to_be_compiled.find(h) == to_be_compiled.end()) {
to_be_compiled.insert(h);
// Later the IR passes will change |stmt|, so we must clone it.
stmt = ker.clone_stmt_on_write();
compilation_workers.enqueue([&, stmt, kernel, h, this]() {
{
// Final lowering
@@ -82,7 +94,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
}

auto context = ker.context;
launch_worker.enqueue([&, h, stmt, context, this] {
launch_worker.enqueue([&, h, task_type = stmt->task_type, context, this] {
FunctionType func;
while (true) {
std::unique_lock<std::mutex> lock(mut);
@@ -95,7 +107,6 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
break;
}
stat.add("launched_kernels", 1.0);
auto task_type = stmt->task_type;
if (task_type == OffloadedStmt::TaskType::listgen) {
stat.add("launched_kernels_list_op", 1.0);
stat.add("launched_kernels_list_gen", 1.0);
@@ -140,18 +151,23 @@ void AsyncEngine::launch(Kernel *kernel) {
kmeta.dummy_root->kernel = kernel;
}
for (std::size_t i = 0; i < offloads.size(); i++) {
auto offload = offloads[i]->as<OffloadedStmt>();
auto cloned = clone_offloaded_task(offload, kernel, kmeta.dummy_root.get());
auto *offload = offloads[i]->as<OffloadedStmt>();
uint64 h;
OffloadedStmt *offl_template = nullptr;
if (kmeta_inited) {
h = kmeta.offloaded_hashes[i];
auto &oc = kmeta.offloaded_cached[i];
h = oc.get_hash();
offl_template = oc.get_template();
} else {
h = hash(cloned.get());
TI_ASSERT(kmeta.offloaded_hashes.size() == i);
kmeta.offloaded_hashes.push_back(h);
auto cloned_offs =
clone_offloaded_task(offload, kernel, kmeta.dummy_root.get());
offl_template = cloned_offs.get();
h = hash(offl_template);
TI_ASSERT(kmeta.offloaded_cached.size() == i);
kmeta.offloaded_cached.emplace_back(std::move(cloned_offs), h);
}
KernelLaunchRecord rec(kernel->program.get_context(), kernel,
std::move(cloned), h);
KernelLaunchRecord rec(kernel->program.get_context(), kernel, offl_template,
h, kmeta.dummy_root.get());
enqueue(std::move(rec));
}
}
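
This loop is where the lazy caching pays off: only the first launch of a kernel pays for clone_offloaded_task plus hash; every later launch reads both from offloaded_cached. A simplified sketch of that branch, with Task and hash_task as hypothetical stand-ins:

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct Task {
  std::unique_ptr<Task> clone() const { return std::make_unique<Task>(*this); }
};

struct CachedTask {  // analogous to OffloadedCachedData
  std::unique_ptr<Task> tmpl;
  std::uint64_t hash;
};

std::uint64_t hash_task(const Task *) { return 0; }  // placeholder hash

// Returns the read-only template for task |i|, filling |cache| on first use.
Task *get_template(std::vector<CachedTask> &cache, std::size_t i,
                   const Task *original, bool cache_ready, std::uint64_t &h) {
  if (cache_ready) {  // every launch after the first: no clone, no hash
    h = cache[i].hash;
    return cache[i].tmpl.get();
  }
  auto cloned = original->clone();  // expensive: done once per task
  h = hash_task(cloned.get());
  cache.push_back(CachedTask{std::move(cloned), h});
  return cache.back().tmpl.get();
}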
@@ -161,7 +177,7 @@ void AsyncEngine::enqueue(KernelLaunchRecord &&t) {

auto &meta = offloaded_metas_[t.h];
// TODO: this is an abuse since it gathers nothing...
auto root_stmt = t.stmt;
auto *root_stmt = t.stmt();
gather_statements(root_stmt, [&](Stmt *stmt) {
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
@@ -225,12 +241,12 @@ bool AsyncEngine::optimize_listgen() {
// Try to eliminate unused listgens
auto &t = task_queue[i];
auto meta = offloaded_metas_[t.h];
auto offload = t.stmt;
const auto *offload = t.stmt();
bool keep = true;
if (offload->task_type == OffloadedStmt::TaskType::listgen) {
// keep
} else if (offload->task_type == OffloadedStmt::TaskType::clear_list) {
TI_ASSERT(task_queue[i + 1].stmt->task_type ==
TI_ASSERT(task_queue[i + 1].stmt()->task_type ==
OffloadedStmt::TaskType::listgen);
auto snode = offload->snode;
if (list_dirty.find(snode) != list_dirty.end() && !list_dirty[snode]) {
@@ -266,16 +282,16 @@ bool AsyncEngine::fuse() {
if (false) {
// (experimental) print tasks
for (int i = 0; i < (int)task_queue.size(); i++) {
fmt::print("{}: {}\n", i, task_queue[i].stmt->task_name());
irpass::print(task_queue[i].stmt);
fmt::print("{}: {}\n", i, task_queue[i].stmt()->task_name());
irpass::print(task_queue[i].stmt());
}
}

for (int i = 0; i < (int)task_queue.size() - 1; i++) {
auto &rec_a = task_queue[i];
auto &rec_b = task_queue[i + 1];
auto task_a = rec_a.stmt;
auto task_b = rec_b.stmt;
auto *task_a = rec_a.stmt();
auto *task_b = rec_b.stmt();
bool is_same_struct_for = task_a->task_type == OffloadedStmt::struct_for &&
task_b->task_type == OffloadedStmt::struct_for &&
task_a->snode == task_b->snode &&
@@ -310,6 +326,9 @@ bool AsyncEngine::fuse() {
kernel_args_match = (check(rec_a.kernel) && check(rec_b.kernel));
}
if (kernel_args_match && (is_same_range_for || is_same_struct_for)) {
// We are about to change both |task_a| and |task_b|. Clone them first.
task_a = rec_a.clone_stmt_on_write();
task_b = rec_b.clone_stmt_on_write();
// TODO: in certain cases this optimization can be wrong!
// Fuse task b into task_a
for (int j = 0; j < (int)task_b->body->size(); j++) {
@@ -334,7 +353,7 @@

// Eliminate empty tasks
for (int i = 0; i < (int)task_queue.size(); i++) {
auto task = task_queue[i].stmt;
auto *task = task_queue[i].stmt();
bool keep = true;
if (task->task_type == OffloadedStmt::struct_for ||
task->task_type == OffloadedStmt::range_for ||
63 changes: 56 additions & 7 deletions taichi/program/async_engine.h
@@ -109,16 +109,33 @@ class KernelLaunchRecord {
public:
Context context;
Kernel *kernel; // TODO: remove this
OffloadedStmt *stmt;
uint64 h; // hash of |stmt|
uint64 h; // hash of |stmt|

KernelLaunchRecord(Context context,
Kernel *kernel,
std::unique_ptr<IRNode> &&stmt,
uint64 h);
OffloadedStmt *stmt,
uint64 h,
Block *dummy_root);

inline OffloadedStmt *stmt() {
return stmt_;
}

// When we need to make changes to |stmt|, call this method first so that
// |stmt| is cloned from the template, leaving the template itself untouched.
//
// Cloning only happens on the first call.
OffloadedStmt *clone_stmt_on_write();

private:
std::unique_ptr<IRNode> stmt_holder_;
// This starts out pointing at the template in OffloadedCachedData. Once
// clone_stmt_on_write() has been invoked, it points to the clone owned by
// |cloned_stmt_holder_|.
OffloadedStmt *stmt_;

// These are for cloning |stmt_|.
Block *dummy_root_; // Not owned
std::unique_ptr<OffloadedStmt> cloned_stmt_holder_;
};

// In charge of (parallel) compilation to binary and (serial) kernel launching
@@ -181,7 +198,40 @@ class AsyncEngine {
private:
struct KernelMeta {
std::unique_ptr<Block> dummy_root;
std::vector<uint64> offloaded_hashes;

// OffloadedCachedData holds data that needs to be computed only once for
// each offloaded task of a kernel. In particular, it holds a cloned offloaded
// task and uses it as a READ-ONLY template. That is, code that later needs
// to mutate this task (e.g. kernel fusion) should do another clone, so that
// the template in this class stays untouched.
//
// This design allows us to clone tasks lazily. It turned out that cloning
// on every kernel launch was too expensive.
struct OffloadedCachedData {
public:
explicit OffloadedCachedData(std::unique_ptr<OffloadedStmt> &&tmpl,
uint64 hash)
: tmpl_(std::move(tmpl)), hash_(hash) {
}

// Get the read-only offloaded task template. Ideally this would be a
// const pointer, but the IR passes don't accept one...
inline OffloadedStmt *get_template() {
return tmpl_.get();
}

inline uint64 get_hash() const {
return hash_;
}

private:
// Hide the unique pointer so that the ownership cannot be accidentally
// transferred.
std::unique_ptr<OffloadedStmt> tmpl_;
uint64 hash_;
};

std::vector<OffloadedCachedData> offloaded_cached;

inline bool initialized() const {
return dummy_root != nullptr;
@@ -197,7 +247,6 @@
// This map provides a dummy Block root for these OffloadedStmt, so that
// get_kernel() could still work correctly.
std::unordered_map<const Kernel *, KernelMeta> kernel_metas_;

std::unordered_map<std::uint64_t, TaskMeta> offloaded_metas_;
};
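
Taken together, the ownership story per offloaded task is: OffloadedCachedData owns the read-only template, each KernelLaunchRecord borrows it, and only a mutating pass forces a second clone. A hypothetical call sequence mirroring launch() and fuse() above:

// In AsyncEngine::launch, once the per-kernel cache is initialized:
//   auto &oc = kmeta.offloaded_cached[i];
//   KernelLaunchRecord rec(kernel->program.get_context(), kernel,
//                          oc.get_template(), oc.get_hash(),
//                          kmeta.dummy_root.get());
//
// In a mutating pass such as AsyncEngine::fuse:
//   OffloadedStmt *task = rec.clone_stmt_on_write();  // template untouched
//   // ... mutate |task| freely ...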

5 changes: 4 additions & 1 deletion taichi/transforms/compile_to_offloads.cpp
@@ -106,6 +106,9 @@ void compile_to_offloads(IRNode *ir,

print("Access flagged II");
irpass::analysis::verify(ir);

irpass::full_simplify(ir, /*after_lower_access=*/false);
print("Simplified III");
}

void compile_to_executable(IRNode *ir,
@@ -166,7 +169,7 @@ void compile_to_executable(IRNode *ir,
irpass::analysis::verify(ir);

irpass::full_simplify(ir, lower_global_access);
print("Simplified III");
print("Simplified IV");

// Final field registration correctness & type checking
irpass::typecheck(ir);