[ir] Move uniquely_accessed_bit_structs from compile_to_offloads to AnalysisManager #2264

Merged: 3 commits, Apr 14, 2021
11 changes: 7 additions & 4 deletions taichi/analysis/gather_uniquely_accessed_pointers.cpp
@@ -1,3 +1,4 @@
+#include "taichi/analysis/gather_uniquely_accessed_pointers.h"
 #include "taichi/ir/ir.h"
 #include "taichi/ir/analysis.h"
 #include "taichi/ir/statements.h"
@@ -224,17 +225,19 @@ class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor {
   }
 };
 
+const std::string GatherUniquelyAccessedBitStructsPass::id =
+    "GatherUniquelyAccessedBitStructsPass";
+
 namespace irpass::analysis {
 std::unordered_map<const SNode *, GlobalPtrStmt *>
 gather_uniquely_accessed_pointers(IRNode *root) {
   // TODO: What about SNodeOpStmts?
   return UniquelyAccessedSNodeSearcher::run(root);
 }
 
-std::unordered_map<OffloadedStmt *,
-                   std::unordered_map<const SNode *, GlobalPtrStmt *>>
-gather_uniquely_accessed_bit_structs(IRNode *root) {
-  return UniquelyAccessedBitStructGatherer::run(root);
+void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr) {
+  amgr->put_pass_result<GatherUniquelyAccessedBitStructsPass>(
+      {UniquelyAccessedBitStructGatherer::run(root)});
 }
 } // namespace irpass::analysis
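A note on the new call: the braced argument aggregate-initializes the pass's Result struct, so the gathered map is moved into the AnalysisManager instead of being returned to the caller. Assuming put_pass_result takes the Result by rvalue reference, as the braced temporary at the call site suggests, the call is equivalent to this sketch:

// Sketch only: build the Result aggregate explicitly, then hand it to the
// manager, which stores it keyed by GatherUniquelyAccessedBitStructsPass.
GatherUniquelyAccessedBitStructsPass::Result result{
    UniquelyAccessedBitStructGatherer::run(root)};
amgr->put_pass_result<GatherUniquelyAccessedBitStructsPass>(std::move(result));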

20 changes: 20 additions & 0 deletions taichi/analysis/gather_uniquely_accessed_pointers.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "taichi/ir/pass.h"
+
+namespace taichi {
+namespace lang {
+
+class GatherUniquelyAccessedBitStructsPass : public Pass {
+ public:
+  static const PassID id;
+
+  struct Result {
+    std::unordered_map<OffloadedStmt *,
+                       std::unordered_map<const SNode *, GlobalPtrStmt *>>
+        uniquely_accessed_bit_structs;
+  };
+};
+
+} // namespace lang
+} // namespace taichi
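The new header shows the convention this PR establishes: a pass is a class with a static PassID id and a nested Result struct, and that is all AnalysisManager needs to store and look up its output. A hedged sketch of what another analysis would declare, with entirely made-up names (CountGlobalStoresPass is not part of this PR):

// Hypothetical follow-on pass using the same convention.
class CountGlobalStoresPass : public Pass {
 public:
  static const PassID id;

  struct Result {
    int num_global_stores{0};
  };
};

// In the matching .cpp, mirroring the definition added above:
// const PassID CountGlobalStoresPass::id = "CountGlobalStoresPass";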
6 changes: 3 additions & 3 deletions taichi/ir/analysis.h
@@ -1,6 +1,8 @@
 #pragma once
 
 #include "taichi/ir/ir.h"
+#include "taichi/ir/pass.h"
+#include "taichi/analysis/gather_uniquely_accessed_pointers.h"
 #include <atomic>
 #include <optional>
 #include <unordered_set>
@@ -101,9 +103,7 @@ std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
 gather_snode_read_writes(IRNode *root);
 std::vector<Stmt *> gather_statements(IRNode *root,
                                       const std::function<bool(Stmt *)> &test);
-std::unordered_map<OffloadedStmt *,
-                   std::unordered_map<const SNode *, GlobalPtrStmt *>>
-gather_uniquely_accessed_bit_structs(IRNode *root);
+void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr);
 std::unordered_map<const SNode *, GlobalPtrStmt *>
 gather_uniquely_accessed_pointers(IRNode *root);
 std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
2 changes: 1 addition & 1 deletion taichi/ir/pass.h
@@ -70,7 +70,7 @@ class AnalysisManager {
       return nullptr;
     }
     using ResultModelT = AnalysisResultModel<typename PassT::Result>;
-    return &(static_cast<std::unique_ptr<ResultModelT>>(result.second)->result);
+    return &(static_cast<ResultModelT *>(result->second.get())->result);
   }
 
   template <typename PassT>
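The one-line body change fixes two real bugs in get_pass_result: result is the iterator returned by find, so its mapped value is result->second, not result.second; and a std::unique_ptr to the base concept cannot be static_cast to a differently-typed unique_ptr, so the downcast has to happen on the owned raw pointer via .get(). Because the method is a template, the broken body only fails to compile once something instantiates it, which this PR presumably does for the first time. A minimal, self-contained sketch of the type-erased storage pattern as I read it from this hunk (the AnalysisResultConcept/AnalysisResultModel definitions are assumed, not copied from taichi):

#include <memory>
#include <string>
#include <unordered_map>

// Base "concept" class: lets results of any type live in one map.
class AnalysisResultConcept {
 public:
  virtual ~AnalysisResultConcept() = default;
};

// "Model" wrapper: holds the concrete result of one pass.
template <typename ResultT>
class AnalysisResultModel : public AnalysisResultConcept {
 public:
  explicit AnalysisResultModel(ResultT result) : result(std::move(result)) {}
  ResultT result;
};

// The manager's storage, keyed by pass id.
std::unordered_map<std::string, std::unique_ptr<AnalysisResultConcept>>
    pass_results_;

// Retrieval: find the entry, then downcast the *raw* pointer. Casting the
// unique_ptr itself would be ill-formed, and would also imply a transfer
// of ownership that is not wanted here.
template <typename ResultT>
ResultT *lookup(const std::string &id) {
  auto it = pass_results_.find(id);
  if (it == pass_results_.end()) {
    return nullptr;
  }
  using ModelT = AnalysisResultModel<ResultT>;
  return &(static_cast<ModelT *>(it->second.get())->result);
}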
10 changes: 4 additions & 6 deletions taichi/ir/transforms.h
@@ -7,6 +7,7 @@
 
 #include "taichi/ir/control_flow_graph.h"
 #include "taichi/ir/ir.h"
+#include "taichi/ir/pass.h"
 #include "taichi/transforms/check_out_of_bound.h"
 #include "taichi/transforms/constant_fold.h"
 #include "taichi/transforms/lower_access.h"
@@ -81,12 +82,9 @@ void demote_dense_struct_fors(IRNode *root);
 bool demote_atomics(IRNode *root, const CompileConfig &config);
 void reverse_segments(IRNode *root); // for autograd
 void detect_read_only(IRNode *root);
-void optimize_bit_struct_stores(
-    IRNode *root,
-    const CompileConfig &config,
-    const std::unordered_map<OffloadedStmt *,
-                             std::unordered_map<const SNode *, GlobalPtrStmt *>>
-        &uniquely_accessed_bit_structs);
+void optimize_bit_struct_stores(IRNode *root,
+                                const CompileConfig &config,
+                                AnalysisManager *amgr);
 
 // compile_to_offloads does the basic compilation to create all the offloaded
 // tasks of a Taichi kernel. It's worth pointing out that this doesn't demote
14 changes: 4 additions & 10 deletions taichi/transforms/compile_to_offloads.cpp
@@ -42,10 +42,7 @@ void compile_to_offloads(IRNode *ir,
   auto print = make_pass_printer(verbose, kernel, ir);
   print("Initial IR");
 
-  AnalysisManager amgr;
-
   if (grad) {
     // TODO(#2193): Support reverse_segments after lower_ast
     TI_ASSERT_INFO(start_from_ast, "CHI does not support autodiff for now.");
     irpass::reverse_segments(ir);
     print("Segment reversed (for autodiff)");
   }
@@ -155,6 +152,8 @@ void offload_to_executable(IRNode *ir,
   // For now, putting this after TLS will disable TLS, because it can only
   // handle range-fors at this point.
 
+  auto amgr = std::make_unique<AnalysisManager>();
+
   print("Start offload_to_executable");
   irpass::analysis::verify(ir);
 
@@ -188,13 +187,9 @@ void offload_to_executable(IRNode *ir,
   print("Atomics demoted II");
   irpass::analysis::verify(ir);
 
-  std::unordered_map<OffloadedStmt *,
-                     std::unordered_map<const SNode *, GlobalPtrStmt *>>
-      uniquely_accessed_bit_structs;
   if (is_extension_supported(config.arch, Extension::quant) &&
       ir->get_config().quant_opt_atomic_demotion) {
-    uniquely_accessed_bit_structs =
-        irpass::analysis::gather_uniquely_accessed_bit_structs(ir);
+    irpass::analysis::gather_uniquely_accessed_bit_structs(ir, amgr.get());
   }
 
   irpass::remove_range_assumption(ir);
@@ -225,8 +220,7 @@ void offload_to_executable(IRNode *ir,
   print("Simplified IV");
 
   if (is_extension_supported(config.arch, Extension::quant)) {
-    irpass::optimize_bit_struct_stores(ir, config,
-                                       uniquely_accessed_bit_structs);
+    irpass::optimize_bit_struct_stores(ir, config, amgr.get());
     print("Bit struct stores optimized");
   }
 
17 changes: 10 additions & 7 deletions taichi/transforms/optimize_bit_struct_stores.cpp
@@ -1,5 +1,6 @@
 #include "taichi/ir/analysis.h"
 #include "taichi/ir/ir.h"
+#include "taichi/ir/pass.h"
 #include "taichi/ir/statements.h"
 #include "taichi/ir/transforms.h"
 #include "taichi/ir/visitors.h"
@@ -200,20 +201,22 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor {
 TLANG_NAMESPACE_BEGIN
 
 namespace irpass {
-void optimize_bit_struct_stores(
-    IRNode *root,
-    const CompileConfig &config,
-    const std::unordered_map<OffloadedStmt *,
-                             std::unordered_map<const SNode *, GlobalPtrStmt *>>
-        &uniquely_accessed_bit_structs) {
+void optimize_bit_struct_stores(IRNode *root,
+                                const CompileConfig &config,
+                                AnalysisManager *amgr) {
   TI_AUTO_PROF;
   CreateBitStructStores::run(root);
   die(root); // remove unused GetCh
   if (config.quant_opt_store_fusion) {
     MergeBitStructStores::run(root);
   }
   if (config.quant_opt_atomic_demotion) {
-    DemoteAtomicBitStructStores::run(root, uniquely_accessed_bit_structs);
+    auto *res = amgr->get_pass_result<GatherUniquelyAccessedBitStructsPass>();
+    TI_ASSERT_INFO(res,
+                   "The optimize_bit_struct_stores pass must be after the "
+                   "gather_uniquely_accessed_bit_structs pass when "
+                   "config.quant_opt_atomic_demotion is true.");
+    DemoteAtomicBitStructStores::run(root, res->uniquely_accessed_bit_structs);
   }
 }
 
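With the map no longer threaded through as a parameter, the ordering constraint between the two passes is enforced at runtime by the assertion above. A hypothetical driver, using only names from this diff but with a simplified guard, showing both the correct ordering and the failure mode:

// Hypothetical driver. If quant_opt_atomic_demotion is on and the gather
// pass never ran on this AnalysisManager, get_pass_result() returns
// nullptr inside optimize_bit_struct_stores and TI_ASSERT_INFO fires.
void run_quant_passes(IRNode *ir, const CompileConfig &config) {
  auto amgr = std::make_unique<AnalysisManager>();
  if (config.quant_opt_atomic_demotion) {
    irpass::analysis::gather_uniquely_accessed_bit_structs(ir, amgr.get());
  }
  irpass::optimize_bit_struct_stores(ir, config, amgr.get());
}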