diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp index b48a02fb6cb8d..8febbe134ba36 100644 --- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp +++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp @@ -1,3 +1,4 @@ +#include "taichi/analysis/gather_uniquely_accessed_pointers.h" #include "taichi/ir/ir.h" #include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" @@ -224,6 +225,9 @@ class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor { } }; +const std::string GatherUniquelyAccessedBitStructsPass::id = + "GatherUniquelyAccessedBitStructsPass"; + namespace irpass::analysis { std::unordered_map gather_uniquely_accessed_pointers(IRNode *root) { @@ -231,10 +235,9 @@ gather_uniquely_accessed_pointers(IRNode *root) { return UniquelyAccessedSNodeSearcher::run(root); } -std::unordered_map> -gather_uniquely_accessed_bit_structs(IRNode *root) { - return UniquelyAccessedBitStructGatherer::run(root); +void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr) { + amgr->put_pass_result( + {UniquelyAccessedBitStructGatherer::run(root)}); } } // namespace irpass::analysis diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.h b/taichi/analysis/gather_uniquely_accessed_pointers.h new file mode 100644 index 0000000000000..2314a5c5b8ede --- /dev/null +++ b/taichi/analysis/gather_uniquely_accessed_pointers.h @@ -0,0 +1,20 @@ +#pragma once + +#include "taichi/ir/pass.h" + +namespace taichi { +namespace lang { + +class GatherUniquelyAccessedBitStructsPass : public Pass { + public: + static const PassID id; + + struct Result { + std::unordered_map> + uniquely_accessed_bit_structs; + }; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index b08c29d43b269..5e53cb16c714d 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -1,6 +1,8 @@ #pragma once #include "taichi/ir/ir.h" +#include "taichi/ir/pass.h" +#include "taichi/analysis/gather_uniquely_accessed_pointers.h" #include #include #include @@ -101,9 +103,7 @@ std::pair, std::unordered_set> gather_snode_read_writes(IRNode *root); std::vector gather_statements(IRNode *root, const std::function &test); -std::unordered_map> -gather_uniquely_accessed_bit_structs(IRNode *root); +void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr); std::unordered_map gather_uniquely_accessed_pointers(IRNode *root); std::unique_ptr> gather_used_atomics( diff --git a/taichi/ir/pass.h b/taichi/ir/pass.h index 9a893ce148dd1..5001157627216 100644 --- a/taichi/ir/pass.h +++ b/taichi/ir/pass.h @@ -70,7 +70,7 @@ class AnalysisManager { return nullptr; } using ResultModelT = AnalysisResultModel; - return &(static_cast>(result.second)->result); + return &(static_cast(result->second.get())->result); } template diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 5055a66b4903b..13758f8337fd2 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -7,6 +7,7 @@ #include "taichi/ir/control_flow_graph.h" #include "taichi/ir/ir.h" +#include "taichi/ir/pass.h" #include "taichi/transforms/check_out_of_bound.h" #include "taichi/transforms/constant_fold.h" #include "taichi/transforms/lower_access.h" @@ -81,12 +82,9 @@ void demote_dense_struct_fors(IRNode *root); bool demote_atomics(IRNode *root, const CompileConfig &config); void reverse_segments(IRNode *root); // for autograd void detect_read_only(IRNode *root); -void optimize_bit_struct_stores( - IRNode *root, - const CompileConfig &config, - const std::unordered_map> - &uniquely_accessed_bit_structs); +void optimize_bit_struct_stores(IRNode *root, + const CompileConfig &config, + AnalysisManager *amgr); // compile_to_offloads does the basic compilation to create all the offloaded // tasks of a Taichi kernel. It's worth pointing out that this doesn't demote diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 0b68cc97e0e37..a868032b9cad5 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -42,10 +42,7 @@ void compile_to_offloads(IRNode *ir, auto print = make_pass_printer(verbose, kernel, ir); print("Initial IR"); - AnalysisManager amgr; if (grad) { - // TODO(#2193): Support reverse_segments after lower_ast - TI_ASSERT_INFO(start_from_ast, "CHI does not support autodiff for now."); irpass::reverse_segments(ir); print("Segment reversed (for autodiff)"); } @@ -155,6 +152,8 @@ void offload_to_executable(IRNode *ir, // For now, putting this after TLS will disable TLS, because it can only // handle range-fors at this point. + auto amgr = std::make_unique(); + print("Start offload_to_executable"); irpass::analysis::verify(ir); @@ -188,13 +187,9 @@ void offload_to_executable(IRNode *ir, print("Atomics demoted II"); irpass::analysis::verify(ir); - std::unordered_map> - uniquely_accessed_bit_structs; if (is_extension_supported(config.arch, Extension::quant) && ir->get_config().quant_opt_atomic_demotion) { - uniquely_accessed_bit_structs = - irpass::analysis::gather_uniquely_accessed_bit_structs(ir); + irpass::analysis::gather_uniquely_accessed_bit_structs(ir, amgr.get()); } irpass::remove_range_assumption(ir); @@ -225,8 +220,7 @@ void offload_to_executable(IRNode *ir, print("Simplified IV"); if (is_extension_supported(config.arch, Extension::quant)) { - irpass::optimize_bit_struct_stores(ir, config, - uniquely_accessed_bit_structs); + irpass::optimize_bit_struct_stores(ir, config, amgr.get()); print("Bit struct stores optimized"); } diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index 49586c057d6d9..0a133ff036ef4 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -1,5 +1,6 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/ir.h" +#include "taichi/ir/pass.h" #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" @@ -200,12 +201,9 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor { TLANG_NAMESPACE_BEGIN namespace irpass { -void optimize_bit_struct_stores( - IRNode *root, - const CompileConfig &config, - const std::unordered_map> - &uniquely_accessed_bit_structs) { +void optimize_bit_struct_stores(IRNode *root, + const CompileConfig &config, + AnalysisManager *amgr) { TI_AUTO_PROF; CreateBitStructStores::run(root); die(root); // remove unused GetCh @@ -213,7 +211,12 @@ void optimize_bit_struct_stores( MergeBitStructStores::run(root); } if (config.quant_opt_atomic_demotion) { - DemoteAtomicBitStructStores::run(root, uniquely_accessed_bit_structs); + auto *res = amgr->get_pass_result(); + TI_ASSERT_INFO(res, + "The optimize_bit_struct_stores pass must be after the " + "gather_uniquely_accessed_bit_structs pass when " + "config.quant_opt_atomic_demotion is true."); + DemoteAtomicBitStructStores::run(root, res->uniquely_accessed_bit_structs); } }