Skip to content

Commit

Permalink
[async] Cache fusion results (#1875)
Browse files Browse the repository at this point in the history
  • Loading branch information
xumingkuan authored Sep 15, 2020
1 parent edb8ed7 commit 7972590
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 31 deletions.
38 changes: 38 additions & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,44 @@ IRNode *IRBank::find(IRHandle ir_handle) {
return result->second.get();
}

// Fuses the task behind |handle_b| into the task behind |handle_a| and returns
// a handle to the fused IR.
//
// Results are memoized in |fuse_bank_|, keyed by the (handle_a, handle_b)
// pair, so the fusion work below only runs the first time a pair is seen.
// |kernel| is forwarded to irpass::full_simplify.
IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
  // operator[] yields a default-constructed (empty) handle for a pair that has
  // not been fused yet; a non-empty handle is a cache hit.
  // NOTE: the cache key ignores |kernel|, i.e. the kernel is assumed to always
  // be the same when the IR handles are the same.
  IRHandle &fused_handle = fuse_bank_[std::make_pair(handle_a, handle_b)];

  if (fused_handle.empty()) {
    TI_INFO("Begin uncached fusion");

    // Both tasks are modified below, so work on clones of them.
    auto clone_a = handle_a.clone();
    auto clone_b = handle_b.clone();
    auto dst_task = clone_a->as<OffloadedStmt>();
    auto src_task = clone_b->as<OffloadedStmt>();

    // TODO: in certain cases this optimization can be wrong!
    // Move every statement of B's body to A's body, then drop the moved-from
    // slots left behind in B.
    auto &src_body = *src_task->body;
    for (int i = 0; i < static_cast<int>(src_body.size()); i++) {
      dst_task->body->insert(std::move(src_body.statements[i]));
    }
    src_body.statements.clear();

    // Redirect every reference to offloaded statement B so it points at A.
    irpass::replace_all_usages_with(dst_task, src_task, dst_task);

    irpass::full_simplify(dst_task, /*after_lower_access=*/false, kernel);
    // For now, re_id is necessary for the hash to be correct.
    irpass::re_id(dst_task);

    const auto fused_hash = get_hash(dst_task);
    // Fill the cache slot first, then hand the fused clone over to insert().
    fused_handle = IRHandle(dst_task, fused_hash);
    insert(std::move(clone_a), fused_hash);

    // TODO: since clone_b's body is empty, can we remove this (i.e., simply
    // delete clone_b here)?
    insert_to_trash_bin(std::move(clone_b));
  }

  return fused_handle;
}

ParallelExecutor::ParallelExecutor(int num_threads)
: num_threads(num_threads),
status(ExecutorStatus::uninitialized),
Expand Down
6 changes: 4 additions & 2 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class IRBank {
void insert_to_trash_bin(std::unique_ptr<IRNode> &&ir);
IRNode *find(IRHandle ir_handle);

// Fuse handle_b into handle_a and return the handle of the fused IR.
// The inputs are cloned before being modified, and the result is cached per
// (handle_a, handle_b) pair in fuse_bank_. The cache key does not include
// |kernel|: the kernel is assumed to be the same whenever the handles are the
// same.
IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;

private:
std::unordered_map<IRNode *, uint64> hash_bank_;
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin; // owns discarded IR to prevent it from being deleted
// TODO:
// std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
};

class ParallelExecutor {
Expand Down
16 changes: 16 additions & 0 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ struct TaskMeta;

class IRHandle {
public:
// Default-constructs an empty handle: null IR pointer and a hash of 0.
IRHandle() : IRHandle(nullptr, 0) {
}

// Builds a handle from an IR node pointer and its hash; both are stored as
// given.
IRHandle(const IRNode *ir, uint64 hash) : ir_(ir), hash_(hash) {
}

Expand All @@ -28,6 +31,10 @@ class IRHandle {
return hash_;
}

// Returns true when this handle does not point to any IR node (e.g. a
// default-constructed handle).
bool empty() const {
  return !ir_;
}

// Two IRHandles are considered the same iff their hash values are the same.
bool operator==(const IRHandle &other_ir_handle) const {
return hash_ == other_ir_handle.hash_;
Expand All @@ -48,6 +55,15 @@ struct hash<taichi::lang::IRHandle> {
return ir_handle.hash();
}
};

// Hash for an ordered pair of IRHandles, so that such a pair can be used as a
// key in unordered containers. The two handle hashes are mixed as
// first * 100000007 + second, which keeps (a, b) and (b, a) distinct in
// general.
template <>
struct hash<std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>> {
  std::size_t operator()(
      const std::pair<taichi::lang::IRHandle, taichi::lang::IRHandle>
          &ir_handles) const noexcept {
    const auto first_hash = ir_handles.first.hash();
    const auto second_hash = ir_handles.second.hash();
    return first_hash * 100000007UL + second_hash;
  }
};
} // namespace std

TLANG_NAMESPACE_BEGIN
Expand Down
34 changes: 5 additions & 29 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,38 +302,14 @@ bool StateFlowGraph::fuse() {
auto *node_a = nodes_[a].get();
auto *node_b = nodes_[b].get();
// TODO: remove debug output
TI_INFO("Fuse: {} <- {}", node_a->string(), node_b->string());
TI_TRACE("Fuse: {} <- {}", node_a->string(), node_b->string());
auto &rec_a = node_a->rec;
auto &rec_b = node_b->rec;
// We are about to change both |task_a| and |task_b|. Clone them first.
auto cloned_task_a = rec_a.ir_handle.clone();
auto cloned_task_b = rec_b.ir_handle.clone();
auto task_a = cloned_task_a->as<OffloadedStmt>();
auto task_b = cloned_task_b->as<OffloadedStmt>();
// TODO: in certain cases this optimization can be wrong!
// Fuse task b into task_a
for (int j = 0; j < (int)task_b->body->size(); j++) {
task_a->body->insert(std::move(task_b->body->statements[j]));
}
task_b->body->statements.clear();

// replace all reference to the offloaded statement B to A
irpass::replace_all_usages_with(task_a, task_b, task_a);

auto kernel = rec_a.kernel;
irpass::full_simplify(task_a, /*after_lower_access=*/false, kernel);
// For now, re_id is necessary for the hash to be correct.
irpass::re_id(task_a);

auto h = ir_bank_->get_hash(task_a);
rec_a.ir_handle = IRHandle(task_a, h);
ir_bank_->insert(std::move(cloned_task_a), h);
rec_b.ir_handle = IRHandle(nullptr, 0);
indices_to_delete.insert(b);
rec_a.ir_handle =
ir_bank_->fuse(rec_a.ir_handle, rec_b.ir_handle, rec_a.kernel);
rec_b.ir_handle = IRHandle();

// TODO: since cloned_task_b->body is empty, can we remove this (i.e.,
// simply delete cloned_task_b here)?
ir_bank_->insert_to_trash_bin(std::move(cloned_task_b));
indices_to_delete.insert(b);

const bool already_had_a_to_b_edge = has_path[a][b];
if (already_had_a_to_b_edge) {
Expand Down

0 comments on commit 7972590

Please sign in to comment.