diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp
index 751b0086e4795..5a787bd19338e 100644
--- a/taichi/analysis/clone.cpp
+++ b/taichi/analysis/clone.cpp
@@ -112,7 +112,7 @@ class IRCloner : public IRVisitor {
 
   static std::unique_ptr<IRNode> run(IRNode *root, Kernel *kernel) {
     if (kernel == nullptr) {
-      kernel = &get_current_program().get_current_kernel();
+      kernel = root->get_kernel();
     }
     std::unique_ptr<IRNode> new_root = root->clone();
     IRCloner cloner(new_root.get());
@@ -120,7 +120,7 @@ class IRCloner : public IRVisitor {
     root->accept(&cloner);
     cloner.phase = IRCloner::replace_operand;
     root->accept(&cloner);
-    irpass::typecheck(new_root.get(), kernel);
+    irpass::typecheck(new_root.get());
     irpass::fix_block_parents(new_root.get());
     return new_root;
   }
diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp
index e8c1cf17495aa..c5379bd889792 100644
--- a/taichi/backends/metal/codegen_metal.cpp
+++ b/taichi/backends/metal/codegen_metal.cpp
@@ -1031,7 +1031,7 @@ CodeGen::CodeGen(Kernel *kernel,
 
 FunctionType CodeGen::compile() {
   auto &config = kernel_->program.config;
   config.demote_dense_struct_fors = true;
-  irpass::compile_to_offloads(kernel_->ir, config,
+  irpass::compile_to_offloads(kernel_->ir.get(), config,
                               /*vectorize=*/false, kernel_->grad,
                               /*ad_use_stack=*/false, config.print_ir);
diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp
index 77540a28d6a7e..a6643505e4faf 100644
--- a/taichi/backends/opengl/codegen_opengl.cpp
+++ b/taichi/backends/opengl/codegen_opengl.cpp
@@ -702,7 +702,7 @@ FunctionType OpenglCodeGen::gen(void) {
 }
 
 void OpenglCodeGen::lower() {
-  auto ir = kernel_->ir;
+  auto ir = kernel_->ir.get();
   auto &config = kernel_->program.config;
   config.demote_dense_struct_fors = true;
   irpass::compile_to_offloads(ir, config,
diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp
index 11a44e7643ae6..a2846ddedee61 100644
--- a/taichi/codegen/codegen.cpp
+++ b/taichi/codegen/codegen.cpp
@@ -15,7 +15,7 @@ TLANG_NAMESPACE_BEGIN
 KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
     : prog(&kernel->program), kernel(kernel), ir(ir) {
   if (ir == nullptr)
-    this->ir = kernel->ir;
+    this->ir = kernel->ir.get();
 
   auto num_stmts = irpass::analysis::count_statements(this->ir);
   if (kernel->is_evaluator)
diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp
index 3b27efc62b05c..3c05a4498c4ad 100644
--- a/taichi/codegen/codegen_llvm.cpp
+++ b/taichi/codegen/codegen_llvm.cpp
@@ -272,7 +272,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
       ir(ir),
       prog(&kernel->program) {
   if (ir == nullptr)
-    this->ir = kernel->ir;
+    this->ir = kernel->ir.get();
 
   initialize_context();
   context_ty = get_runtime_type("Context");
diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp
index 109088aa1d045..5fb2fdf644397 100644
--- a/taichi/ir/ir.cpp
+++ b/taichi/ir/ir.cpp
@@ -297,6 +297,14 @@ IRNode *Stmt::get_ir_root() {
   return dynamic_cast<IRNode *>(block);
 }
 
+Kernel *Stmt::get_kernel() const {
+  if (parent) {
+    return parent->get_kernel();
+  } else {
+    return nullptr;
+  }
+}
+
 std::vector<Stmt *> Stmt::get_operands() const {
   std::vector<Stmt *> ret;
   for (int i = 0; i < num_operands(); i++) {
@@ -706,6 +714,17 @@ Stmt *Block::mask() {
   }
 }
 
+Kernel *Block::get_kernel() const {
+  Block *parent = this->parent;
+  if (parent == nullptr) {
+    return kernel;
+  }
+  while (parent->parent) {
+    parent = parent->parent;
+  }
+  return parent->kernel;
+}
+
 void Block::set_statements(VecStatement &&stmts) {
   statements.clear();
   for (int i = 0; i < (int)stmts.size(); i++) {
diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h
index 57aa1d0a88f59..0c052302a117a 100644
--- a/taichi/ir/ir.h
+++ b/taichi/ir/ir.h
@@ -245,6 +245,9 @@ class IRNode {
   virtual void accept(IRVisitor *visitor) {
     TI_NOT_IMPLEMENTED
   }
+  virtual Kernel *get_kernel() const {
+    return nullptr;
+  }
   virtual ~IRNode() = default;
 
   template <typename T>
@@ -553,6 +556,8 @@ class Stmt : public IRNode {
 
   IRNode *get_ir_root();
 
+  Kernel *get_kernel() const override;
+
   virtual void repeat(int factor) {
     ret_type.width *= factor;
   }
@@ -809,6 +814,7 @@ class Block : public IRNode {
   std::vector<std::unique_ptr<Stmt>> statements, trash_bin;
   Stmt *mask_var;
   std::vector<SNode *> stop_gradients;
+  Kernel *kernel;
 
   // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
   // variables, and AllocaStmt for other variables.
@@ -817,6 +823,7 @@ class Block : public IRNode {
   Block() {
     mask_var = nullptr;
     parent = nullptr;
+    kernel = nullptr;
   }
 
   bool has_container_statements();
@@ -838,6 +845,7 @@ class Block : public IRNode {
                     bool replace_usages = true);
   Stmt *lookup_var(const Identifier &ident) const;
   Stmt *mask();
+  Kernel *get_kernel() const override;
 
   Stmt *back() const {
     return statements.back().get();
diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h
index 48a69e76bb047..7977e7431eaf1 100644
--- a/taichi/ir/transforms.h
+++ b/taichi/ir/transforms.h
@@ -14,16 +14,15 @@ void re_id(IRNode *root);
 void flag_access(IRNode *root);
 void die(IRNode *root);
 void simplify(IRNode *root, Kernel *kernel = nullptr);
-bool alg_simp(IRNode *root, const CompileConfig &config);
+
+bool alg_simp(IRNode *root);
 void whole_kernel_cse(IRNode *root);
 void variable_optimization(IRNode *root, bool after_lower_access);
 void extract_constant(IRNode *root);
-void full_simplify(IRNode *root,
-                   const CompileConfig &config,
-                   Kernel *kernel = nullptr);
+void full_simplify(IRNode *root, Kernel *kernel = nullptr);
 void print(IRNode *root, std::string *output = nullptr);
 void lower(IRNode *root);
-void typecheck(IRNode *root, Kernel *kernel = nullptr);
+void typecheck(IRNode *root);
 void loop_vectorize(IRNode *root);
 void slp_vectorize(IRNode *root);
 void vector_split(IRNode *root, int max_width, bool serial_schedule);
diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp
index 459bb2ab4bdc3..600fa71906ec4 100644
--- a/taichi/program/async_engine.cpp
+++ b/taichi/program/async_engine.cpp
@@ -51,7 +51,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
       flag_access(stmt);
       lower_access(stmt, true, kernel);
       flag_access(stmt);
-      full_simplify(stmt, kernel->program.config, kernel);
+      full_simplify(stmt, kernel);
       // analysis::verify(stmt);
     }
     auto func = CodeGenCPU(kernel, stmt).codegen();
@@ -108,7 +108,7 @@ ExecutionQueue::ExecutionQueue()
 void AsyncEngine::launch(Kernel *kernel) {
   if (!kernel->lowered)
     kernel->lower(false);
-  auto block = dynamic_cast<Block *>(kernel->ir);
+  auto block = dynamic_cast<Block *>(kernel->ir.get());
   TI_ASSERT(block);
   auto &offloads = block->statements;
   for (std::size_t i = 0; i < offloads.size(); i++) {
@@ -266,7 +266,7 @@ bool AsyncEngine::fuse() {
      irpass::fix_block_parents(task_a);
 
      auto kernel = task_queue[i].kernel;
-     irpass::full_simplify(task_a, kernel->program.config, kernel);
+     irpass::full_simplify(task_a, kernel);
      task_queue[i].h = hash(task_a);
 
      modified = true;
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index af391b7713b5e..2ca6a7de373f8 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -36,14 +36,14 @@ Kernel::Kernel(Program &program,
   is_evaluator = false;
   compiled = nullptr;
   taichi::lang::context = std::make_unique<FrontendContext>();
-  ir_holder = taichi::lang::context->get_root();
-  ir = ir_holder.get();
+  ir = taichi::lang::context->get_root();
 
   {
     CurrentKernelGuard _(program, this);
     program.start_function_definition(this);
     func();
     program.end_function_definition();
+    ir->as<Block>()->kernel = this;
   }
 
   arch = program.config.arch;
@@ -74,7 +74,7 @@ void Kernel::lower(bool lower_access) {  // TODO: is a "Lowerer" class necessary
     if (is_accessor && !config.print_accessor_ir)
       verbose = false;
     irpass::compile_to_offloads(
-        ir, config, /*vectorize*/ arch_is_cpu(arch), grad,
+        ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
         /*ad_use_stack*/ true, verbose, /*lower_global_access*/ lower_access);
   } else {
     TI_NOT_IMPLEMENTED
diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h
index e9a714b92df1e..9ad53c98c8703 100644
--- a/taichi/program/kernel.h
+++ b/taichi/program/kernel.h
@@ -10,8 +10,7 @@ class Program;
 
 class Kernel {
  public:
-  std::unique_ptr<IRNode> ir_holder;
-  IRNode *ir;
+  std::unique_ptr<IRNode> ir;
   Program &program;
   FunctionType compiled;
   std::string name;
diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp
index f663a6c006a05..51dc0c1137e0c 100644
--- a/taichi/transforms/alg_simp.cpp
+++ b/taichi/transforms/alg_simp.cpp
@@ -178,7 +178,8 @@ class AlgSimp : public BasicStmtVisitor {
 
 namespace irpass {
 
-bool alg_simp(IRNode *root, const CompileConfig &config) {
+bool alg_simp(IRNode *root) {
+  const auto &config = root->get_kernel()->program.config;
   return AlgSimp::run(root, config.fast_math);
 }
 
diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp
index 1b1c15c328749..c778d76d1fe3c 100644
--- a/taichi/transforms/compile_to_offloads.cpp
+++ b/taichi/transforms/compile_to_offloads.cpp
@@ -53,9 +53,9 @@ void compile_to_offloads(IRNode *ir,
 
   if (grad) {
     irpass::demote_atomics(ir);
-    irpass::full_simplify(ir, config);
+    irpass::full_simplify(ir);
     irpass::make_adjoint(ir, ad_use_stack);
-    irpass::full_simplify(ir, config);
+    irpass::full_simplify(ir);
     print("Adjoint");
     irpass::analysis::verify(ir);
   }
@@ -91,7 +91,7 @@ void compile_to_offloads(IRNode *ir,
     irpass::analysis::verify(ir);
   }
 
-  irpass::full_simplify(ir, config);
+  irpass::full_simplify(ir);
   print("Simplified II");
   irpass::analysis::verify(ir);
 
@@ -122,7 +122,7 @@ void compile_to_offloads(IRNode *ir,
   irpass::variable_optimization(ir, true);
   print("Store forwarded II");
 
-  irpass::full_simplify(ir, config);
+  irpass::full_simplify(ir);
   print("Simplified III");
 
   // Final field registration correctness & type checking
diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp
index bfd337b17c5a1..3d8858b217048 100644
--- a/taichi/transforms/constant_fold.cpp
+++ b/taichi/transforms/constant_fold.cpp
@@ -113,13 +113,14 @@ class ConstantFold : public BasicStmtVisitor {
                           rhs.dt,
                           true};
     auto *ker = get_jit_evaluator_kernel(id);
-    auto &ctx = get_current_program().get_context();
+    auto &current_program = stmt->get_kernel()->program;
+    auto &ctx = current_program.get_context();
     ContextArgSaveGuard _(
         ctx);  // save input args, prevent override current kernel
     ctx.set_arg(0, lhs.val_i64);
     ctx.set_arg(1, rhs.val_i64);
     (*ker)();
-    ret.val_i64 = get_current_program().fetch_result(0);
+    ret.val_i64 = current_program.fetch_result(0);
     return true;
   }
 
@@ -135,12 +136,13 @@ class ConstantFold : public BasicStmtVisitor {
                           stmt->cast_type,
                           false};
     auto *ker = get_jit_evaluator_kernel(id);
-    auto &ctx = get_current_program().get_context();
+    auto &current_program = stmt->get_kernel()->program;
+    auto &ctx = current_program.get_context();
     ContextArgSaveGuard _(
         ctx);  // save input args, prevent override current kernel
     ctx.set_arg(0, operand.val_i64);
     (*ker)();
-    ret.val_i64 = get_current_program().fetch_result(0);
+    ret.val_i64 = current_program.fetch_result(0);
     return true;
   }
 
@@ -204,7 +206,8 @@ void constant_fold(IRNode *root) {
   // disable constant_fold when config.debug is turned on.
   // Discussion:
   // https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010
-  if (get_current_program().config.debug) {
+  auto kernel = root->get_kernel();
+  if (kernel && kernel->program.config.debug) {
     TI_TRACE("config.debug enabled, ignoring constant fold");
     return;
   }
diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp
index a4b4ca86ea71c..f0eb871001304 100644
--- a/taichi/transforms/lower_access.cpp
+++ b/taichi/transforms/lower_access.cpp
@@ -246,7 +246,7 @@ namespace irpass {
 
 void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel) {
   LowerAccess::run(root, lower_atomic);
-  typecheck(root, kernel);
+  typecheck(root);
 }
 
 }  // namespace irpass
diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp
index aaa62e305670e..22259270eaa7e 100644
--- a/taichi/transforms/simplify.cpp
+++ b/taichi/transforms/simplify.cpp
@@ -807,7 +807,7 @@ class BasicBlockSimplify : public IRVisitor {
         stmt->insert_before_me(std::move(sum));
         stmt->parent->erase(stmt);
         // get types of adds and muls
-        irpass::typecheck(stmt->parent, kernel);
+        irpass::typecheck(stmt->parent);
         throw IRModified();
       }
 
@@ -1160,10 +1160,10 @@ void simplify(IRNode *root, Kernel *kernel) {
   }
 }
 
-void full_simplify(IRNode *root, const CompileConfig &config, Kernel *kernel) {
+void full_simplify(IRNode *root, Kernel *kernel) {
   constant_fold(root);
   if (advanced_optimization) {
-    alg_simp(root, config);
+    alg_simp(root);
     die(root);
     whole_kernel_cse(root);
   }
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index c0f6d4efe5318..5e92bd80870c2 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -17,10 +17,11 @@ class TypeCheck : public IRVisitor {
   CompileConfig config;
 
  public:
-  TypeCheck(Kernel *kernel) : kernel(kernel) {
-    // TODO: remove dependency on get_current_program here
-    if (current_program != nullptr)
-      config = get_current_program().config;
+  TypeCheck(IRNode *root) {
+    kernel = root->get_kernel();
+    if (kernel != nullptr) {
+      config = kernel->program.config;
+    }
     allow_undefined_visitor = true;
   }
 
@@ -316,7 +317,7 @@ class TypeCheck : public IRVisitor {
   void visit(ArgLoadStmt *stmt) {
     Kernel *current_kernel = kernel;
     if (current_kernel == nullptr) {
-      current_kernel = &get_current_program().get_current_kernel();
+      current_kernel = stmt->get_kernel();
     }
     auto &args = current_kernel->args;
     TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
@@ -326,7 +327,7 @@ class TypeCheck : public IRVisitor {
   void visit(KernelReturnStmt *stmt) {
     Kernel *current_kernel = kernel;
     if (current_kernel == nullptr) {
-      current_kernel = &get_current_program().get_current_kernel();
+      current_kernel = stmt->get_kernel();
     }
     auto &rets = current_kernel->rets;
     TI_ASSERT(rets.size() >= 1);
@@ -416,9 +417,9 @@ class TypeCheck : public IRVisitor {
 
 namespace irpass {
 
-void typecheck(IRNode *root, Kernel *kernel) {
+void typecheck(IRNode *root) {
   analysis::check_fields_registered(root);
-  TypeCheck inst(kernel);
+  TypeCheck inst(root);
   root->accept(&inst);
 }
 
diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp
index df5f59d3d55cb..ba45d2a03e790 100644
--- a/tests/cpp/test_alg_simp.cpp
+++ b/tests/cpp/test_alg_simp.cpp
@@ -11,6 +11,11 @@ TI_TEST("alg_simp") {
 
     auto block = std::make_unique<Block>();
 
+    auto func = []() {};
+    auto kernel =
+        std::make_unique<Kernel>(get_current_program(), func, "fake_kernel");
+    block->kernel = kernel.get();
+
     auto global_load_addr = block->push_back(
         0, VectorType(1, DataType::i32));
     auto global_load = block->push_back(global_load_addr);
@@ -27,8 +32,8 @@ TI_TEST("alg_simp") {
 
     // irpass::print(block.get());
 
-    irpass::alg_simp(block.get(), CompileConfig());  // should eliminate add
-    irpass::die(block.get());  // should eliminate zero
+    irpass::alg_simp(block.get());  // should eliminate add
+    irpass::die(block.get());       // should eliminate zero
 
     // irpass::print(block.get());
     TI_CHECK(block->size() == 4);  // two addresses, one load, one store
@@ -41,6 +46,11 @@ TI_TEST("alg_simp") {
 
     auto block = std::make_unique<Block>();
 
+    auto func = []() {};
+    auto kernel =
+        std::make_unique<Kernel>(get_current_program(), func, "fake_kernel");
+    block->kernel = kernel.get();
+
     auto global_load_addr = block->push_back(
         0, VectorType(1, DataType::f32));
     auto global_load = block->push_back(global_load_addr);
@@ -61,9 +71,8 @@ TI_TEST("alg_simp") {
 
     // irpass::print(block.get());
 
-    irpass::alg_simp(block.get(),
-                     CompileConfig());  // should eliminate mul, div, sub
-    irpass::die(block.get());           // should eliminate zero, one
+    irpass::alg_simp(block.get());  // should eliminate mul, div, sub
+    irpass::die(block.get());       // should eliminate zero, one
 
     // irpass::print(block.get());
 
@@ -75,6 +84,10 @@ TI_TEST("alg_simp") {
     TI_TEST_PROGRAM;
 
     auto block = std::make_unique<Block>();
+    auto func = []() {};
+    auto kernel =
+        std::make_unique<Kernel>(get_current_program(), func, "fake_kernel");
+    block->kernel = kernel.get();
 
     auto global_load_addr = block->push_back(
         0, VectorType(1, DataType::i32));
@@ -94,13 +107,15 @@ TI_TEST("alg_simp") {
 
     CompileConfig config_without_fast_math;
     config_without_fast_math.fast_math = false;
-    irpass::alg_simp(block.get(),
-                     config_without_fast_math);  // should eliminate mul, add
-    irpass::die(block.get());  // should eliminate zero, load
+    kernel->program.config = config_without_fast_math;
+
+    irpass::alg_simp(block.get());  // should eliminate mul, add
+    irpass::die(block.get());       // should eliminate zero, load
 
     TI_CHECK(block->size() == 3);  // one address, one one, one store
 
     block = std::make_unique<Block>();
+    block->kernel = kernel.get();
 
     global_load_addr = block->push_back(
         8, VectorType(1, DataType::f32));
@@ -117,16 +132,16 @@ TI_TEST("alg_simp") {
     TI_CHECK(block->size() == 10);
 
     irpass::constant_fold(block.get());  // should change 2 casts into const
-    irpass::alg_simp(block.get(),
-                     config_without_fast_math);  // should not eliminate
-    irpass::die(block.get());  // should eliminate 2 const
+    irpass::alg_simp(block.get());  // should not eliminate
+    irpass::die(block.get());       // should eliminate 2 const
     TI_CHECK(block->size() == 8);
 
     CompileConfig config_with_fast_math;
     config_with_fast_math.fast_math = true;
-    irpass::alg_simp(block.get(),
-                     config_with_fast_math);  // should eliminate mul, add
-    irpass::die(block.get());  // should eliminate zero, load
+    kernel->program.config = config_with_fast_math;
+
+    irpass::alg_simp(block.get());  // should eliminate mul, add
+    irpass::die(block.get());       // should eliminate zero, load
 
     TI_CHECK(block->size() == 3);  // one address, one one, one store
   }
@@ -147,11 +162,15 @@ TI_TEST("alg_simp") {
     auto global_store =
         block->push_back(global_store_addr, and_result);
 
+    auto func = []() {};
+    auto kernel =
+        std::make_unique<Kernel>(get_current_program(), func, "fake_kernel");
+    block->kernel = kernel.get();
     irpass::typecheck(block.get());
     TI_CHECK(block->size() == 6);
 
-    irpass::alg_simp(block.get(), CompileConfig());  // should eliminate and
-    irpass::die(block.get());  // should eliminate zero
+    irpass::alg_simp(block.get());  // should eliminate and
+    irpass::die(block.get());       // should eliminate zero
 
     TI_CHECK(block->size() == 4);  // two addresses, one load, one store
     TI_CHECK((*block)[0]->is());
diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp
index 66a00278e1eb1..e4d4085f7e2bd 100644
--- a/tests/cpp/test_simplify.cpp
+++ b/tests/cpp/test_simplify.cpp
@@ -12,6 +12,11 @@ TI_TEST("simplify") {
 
     auto block = std::make_unique<Block>();
 
+    auto func = []() {};
+    auto kernel =
+        std::make_unique<Kernel>(get_current_program(), func, "fake_kernel");
+    block->kernel = kernel.get();
+
     auto get_root = block->push_back();
     auto linearized_empty = block->push_back(
         std::vector(), std::vector());
@@ -34,7 +39,7 @@ TI_TEST("simplify") {
 
     // TI_CHECK(block->size() == 11);  // not required to check size here
     irpass::constant_fold(block.get());
-    irpass::alg_simp(block.get(), CompileConfig());
+    irpass::alg_simp(block.get());
     irpass::die(block.get());  // should eliminate consts
     irpass::simplify(block.get());
     if (advanced_optimization) {