Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Cache fusion results #1875

Merged
merged 1 commit into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,44 @@ IRNode *IRBank::find(IRHandle ir_handle) {
return result->second.get();
}

// Fuses the task referenced by |handle_b| into the task referenced by
// |handle_a| and returns a handle to the fused IR. Results are memoized in
// |fuse_bank_|, keyed by the ordered (handle_a, handle_b) pair, so a repeated
// request returns the stored handle without redoing the work.
IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
  // operator[] leaves an empty placeholder on a miss; it is filled in below.
  auto &cached = fuse_bank_[std::make_pair(handle_a, handle_b)];
  const bool cache_hit = !cached.empty();
  if (cache_hit) {
    // The kernel is assumed to be identical whenever both IR handles match.
    return cached;
  }

  TI_INFO("Begin uncached fusion");
  // Both inputs are modified below, so operate on private clones.
  auto owned_a = handle_a.clone();
  auto owned_b = handle_b.clone();
  auto fused = owned_a->as<OffloadedStmt>();
  auto drained = owned_b->as<OffloadedStmt>();
  // TODO: in certain cases this optimization can be wrong!
  // Append every statement of B's body to A's body, in order.
  for (int idx = 0; idx < static_cast<int>(drained->body->size()); ++idx) {
    fused->body->insert(std::move(drained->body->statements[idx]));
  }
  // Only moved-from slots remain in B's body; drop them.
  drained->body->statements.clear();

  // Redirect every use of offloaded statement B inside A to A itself.
  irpass::replace_all_usages_with(fused, drained, fused);

  irpass::full_simplify(fused, /*after_lower_access=*/false, kernel);
  // For now, re_id is necessary for the hash to be correct.
  irpass::re_id(fused);

  // Record the result in the fusion cache and hand ownership of the fused IR
  // to the bank under its new hash.
  auto fused_hash = get_hash(fused);
  cached = IRHandle(fused, fused_hash);
  insert(std::move(owned_a), fused_hash);

  // TODO: since the clone of B now has an empty body, can we remove this
  // (i.e., simply delete the clone here)?
  insert_to_trash_bin(std::move(owned_b));

  return cached;
}

ParallelExecutor::ParallelExecutor(int num_threads)
: num_threads(num_threads),
status(ExecutorStatus::uninitialized),
Expand Down
6 changes: 4 additions & 2 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class IRBank {
void insert_to_trash_bin(std::unique_ptr<IRNode> &&ir);
IRNode *find(IRHandle ir_handle);

// Fuse handle_b into handle_a
IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;

private:
std::unordered_map<IRNode *, uint64> hash_bank_;
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin; // prevent IR from deleted
// TODO:
// std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
};

class ParallelExecutor {
Expand Down
16 changes: 16 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ struct TaskMeta;

class IRHandle {
public:
// Default-constructs an empty handle: null IR pointer and a hash of 0.
IRHandle() : ir_{nullptr}, hash_{0} {
}

IRHandle(const IRNode *ir, uint64 hash) : ir_(ir), hash_(hash) {
}

Expand All @@ -28,6 +31,10 @@ class IRHandle {
return hash_;
}

// Returns true when this handle does not point at any IR node.
bool empty() const {
  const bool has_ir = (ir_ != nullptr);
  return !has_ir;
}

// Two IRHandles are considered the same iff their hash values are the same.
bool operator==(const IRHandle &other_ir_handle) const {
return hash_ == other_ir_handle.hash_;
Expand All @@ -48,6 +55,15 @@ struct hash<taichi::lang::IRHandle> {
return ir_handle.hash();
}
};

// Hash for an ordered pair of IR handles: the first handle's hash is scaled
// by a constant multiplier and the second handle's hash is added, so swapping
// the two handles generally yields a different value.
template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
  std::size_t operator()(
      const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
          &ir_handles) const noexcept {
    const auto first_hash = ir_handles.first.hash();
    const auto second_hash = ir_handles.second.hash();
    return first_hash * 100000007UL + second_hash;
  }
};
} // namespace std

TLANG_NAMESPACE_BEGIN
Expand Down
34 changes: 5 additions & 29 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,38 +302,14 @@ bool StateFlowGraph::fuse() {
auto *node_a = nodes_[a].get();
auto *node_b = nodes_[b].get();
// TODO: remove debug output
TI_INFO("Fuse: {} <- {}", node_a->string(), node_b->string());
TI_TRACE("Fuse: {} <- {}", node_a->string(), node_b->string());
auto &rec_a = node_a->rec;
auto &rec_b = node_b->rec;
// We are about to change both |task_a| and |task_b|. Clone them first.
auto cloned_task_a = rec_a.ir_handle.clone();
auto cloned_task_b = rec_b.ir_handle.clone();
auto task_a = cloned_task_a->as<OffloadedStmt>();
auto task_b = cloned_task_b->as<OffloadedStmt>();
// TODO: in certain cases this optimization can be wrong!
// Fuse task b into task_a
for (int j = 0; j < (int)task_b->body->size(); j++) {
task_a->body->insert(std::move(task_b->body->statements[j]));
}
task_b->body->statements.clear();

// replace all reference to the offloaded statement B to A
irpass::replace_all_usages_with(task_a, task_b, task_a);

auto kernel = rec_a.kernel;
irpass::full_simplify(task_a, /*after_lower_access=*/false, kernel);
// For now, re_id is necessary for the hash to be correct.
irpass::re_id(task_a);

auto h = ir_bank_->get_hash(task_a);
rec_a.ir_handle = IRHandle(task_a, h);
ir_bank_->insert(std::move(cloned_task_a), h);
rec_b.ir_handle = IRHandle(nullptr, 0);
indices_to_delete.insert(b);
rec_a.ir_handle =
ir_bank_->fuse(rec_a.ir_handle, rec_b.ir_handle, rec_a.kernel);
rec_b.ir_handle = IRHandle();

// TODO: since cloned_task_b->body is empty, can we remove this (i.e.,
// simply delete cloned_task_b here)?
ir_bank_->insert_to_trash_bin(std::move(cloned_task_b));
indices_to_delete.insert(b);

const bool already_had_a_to_b_edge = has_path[a][b];
if (already_had_a_to_b_edge) {
Expand Down