From 27850ff7fb510a8616035b496e66ca63ced6ff8c Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 01:12:35 +0800 Subject: [PATCH 01/15] [refactor] [ir] IR system refactorings part 1 --- taichi/analysis/clone.cpp | 2 +- taichi/backends/metal/codegen_metal.cpp | 2 +- taichi/backends/opengl/codegen_opengl.cpp | 2 +- taichi/codegen/codegen.cpp | 2 +- taichi/codegen/codegen_llvm.cpp | 2 +- taichi/ir/ir.cpp | 16 ++++++++++++++++ taichi/ir/ir.h | 7 +++++++ taichi/program/async_engine.cpp | 2 +- taichi/program/kernel.cpp | 6 +++--- taichi/program/kernel.h | 3 +-- taichi/transforms/constant_fold.cpp | 12 +++++++----- taichi/transforms/type_check.cpp | 6 +++--- 12 files changed, 43 insertions(+), 19 deletions(-) diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index 751b0086e4795..2415a04656f60 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -112,7 +112,7 @@ class IRCloner : public IRVisitor { static std::unique_ptr run(IRNode *root, Kernel *kernel) { if (kernel == nullptr) { - kernel = &get_current_program().get_current_kernel(); + kernel = root->get_kernel()->program.current_kernel; } std::unique_ptr new_root = root->clone(); IRCloner cloner(new_root.get()); diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 34c399272d859..a56d95bb70ffe 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -1009,7 +1009,7 @@ CodeGen::CodeGen(Kernel *kernel, FunctionType CodeGen::compile() { auto &config = kernel_->program.config; config.demote_dense_struct_fors = true; - irpass::compile_to_offloads(kernel_->ir, config, + irpass::compile_to_offloads(kernel_->ir.get(), config, /*vectorize=*/false, kernel_->grad, /*ad_use_stack=*/false, config.print_ir); diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 77540a28d6a7e..a6643505e4faf 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp 
+++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -702,7 +702,7 @@ FunctionType OpenglCodeGen::gen(void) { } void OpenglCodeGen::lower() { - auto ir = kernel_->ir; + auto ir = kernel_->ir.get(); auto &config = kernel_->program.config; config.demote_dense_struct_fors = true; irpass::compile_to_offloads(ir, config, diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index 11a44e7643ae6..a2846ddedee61 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -15,7 +15,7 @@ TLANG_NAMESPACE_BEGIN KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir) : prog(&kernel->program), kernel(kernel), ir(ir) { if (ir == nullptr) - this->ir = kernel->ir; + this->ir = kernel->ir.get(); auto num_stmts = irpass::analysis::count_statements(this->ir); if (kernel->is_evaluator) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index ec04e8aa48fd3..263e41179eeb2 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -272,7 +272,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir) ir(ir), prog(&kernel->program) { if (ir == nullptr) - this->ir = kernel->ir; + this->ir = kernel->ir.get(); initialize_context(); context_ty = get_runtime_type("Context"); diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index f734f838bdf4e..8ba37d10ce6cb 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -298,6 +298,14 @@ IRNode *Stmt::get_ir_root() { return dynamic_cast(block); } +Kernel* Stmt::get_kernel() const { + if (parent) { + return parent->get_kernel(); + } else { + return nullptr; + } +} + std::vector Stmt::get_operands() const { std::vector ret; for (int i = 0; i < num_operands(); i++) { @@ -707,6 +715,14 @@ Stmt *Block::mask() { } } +Kernel* Block::get_kernel() const { + Block *parent = this->parent; + while (parent->parent) { + parent = parent->parent; + } + return parent->kernel; +} + void Block::set_statements(VecStatement &&stmts) { statements.clear(); for (int i = 0; i < 
(int)stmts.size(); i++) { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 48acaeceb8a7d..043c8513a5bbf 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -244,6 +244,9 @@ class IRNode { virtual void accept(IRVisitor *visitor) { TI_NOT_IMPLEMENTED } + virtual Kernel* get_kernel() const { + return nullptr; + } virtual ~IRNode() = default; template @@ -552,6 +555,8 @@ class Stmt : public IRNode { IRNode *get_ir_root(); + Kernel* get_kernel() const override; + virtual void repeat(int factor) { ret_type.width *= factor; } @@ -808,6 +813,7 @@ class Block : public IRNode { std::vector> statements, trash_bin; Stmt *mask_var; std::vector stop_gradients; + Kernel* kernel; // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop // variables, and AllocaStmt for other variables. @@ -837,6 +843,7 @@ class Block : public IRNode { bool replace_usages = true); Stmt *lookup_var(const Identifier &ident) const; Stmt *mask(); + Kernel* get_kernel() const override; Stmt *back() const { return statements.back().get(); diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 459bb2ab4bdc3..aadeedc60a51d 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -108,7 +108,7 @@ ExecutionQueue::ExecutionQueue() void AsyncEngine::launch(Kernel *kernel) { if (!kernel->lowered) kernel->lower(false); - auto block = dynamic_cast(kernel->ir); + auto block = dynamic_cast(kernel->ir.get()); TI_ASSERT(block); auto &offloads = block->statements; for (std::size_t i = 0; i < offloads.size(); i++) { diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index af391b7713b5e..41f799453322a 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -36,14 +36,14 @@ Kernel::Kernel(Program &program, is_evaluator = false; compiled = nullptr; taichi::lang::context = std::make_unique(); - ir_holder = taichi::lang::context->get_root(); - ir = ir_holder.get(); + ir = taichi::lang::context->get_root(); { 
CurrentKernelGuard _(program, this); program.start_function_definition(this); func(); program.end_function_definition(); + dynamic_cast(ir.get())->kernel = this; } arch = program.config.arch; @@ -74,7 +74,7 @@ void Kernel::lower(bool lower_access) { // TODO: is a "Lowerer" class necessary if (is_accessor && !config.print_accessor_ir) verbose = false; irpass::compile_to_offloads( - ir, config, /*vectorize*/ arch_is_cpu(arch), grad, + ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad, /*ad_use_stack*/ true, verbose, /*lower_global_access*/ lower_access); } else { TI_NOT_IMPLEMENTED diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index e9a714b92df1e..9ad53c98c8703 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -10,8 +10,7 @@ class Program; class Kernel { public: - std::unique_ptr ir_holder; - IRNode *ir; + std::unique_ptr ir; Program &program; FunctionType compiled; std::string name; diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index bfd337b17c5a1..00ec883a99d58 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -113,13 +113,14 @@ class ConstantFold : public BasicStmtVisitor { rhs.dt, true}; auto *ker = get_jit_evaluator_kernel(id); - auto &ctx = get_current_program().get_context(); + auto &current_program = stmt->get_kernel()->program; + auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel ctx.set_arg(0, lhs.val_i64); ctx.set_arg(1, rhs.val_i64); (*ker)(); - ret.val_i64 = get_current_program().fetch_result(0); + ret.val_i64 = current_program.fetch_result(0); return true; } @@ -135,12 +136,13 @@ class ConstantFold : public BasicStmtVisitor { stmt->cast_type, false}; auto *ker = get_jit_evaluator_kernel(id); - auto &ctx = get_current_program().get_context(); + auto &current_program = stmt->get_kernel()->program; + auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); 
// save input args, prevent override current kernel ctx.set_arg(0, operand.val_i64); (*ker)(); - ret.val_i64 = get_current_program().fetch_result(0); + ret.val_i64 = current_program.fetch_result(0); return true; } @@ -204,7 +206,7 @@ void constant_fold(IRNode *root) { // disable constant_fold when config.debug is turned on. // Discussion: // https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010 - if (get_current_program().config.debug) { + if (root->get_kernel()->program.config.debug) { TI_TRACE("config.debug enabled, ignoring constant fold"); return; } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index c0f6d4efe5318..a5558261b53a9 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -20,7 +20,7 @@ class TypeCheck : public IRVisitor { TypeCheck(Kernel *kernel) : kernel(kernel) { // TODO: remove dependency on get_current_program here if (current_program != nullptr) - config = get_current_program().config; + config = kernel->program.config; allow_undefined_visitor = true; } @@ -316,7 +316,7 @@ class TypeCheck : public IRVisitor { void visit(ArgLoadStmt *stmt) { Kernel *current_kernel = kernel; if (current_kernel == nullptr) { - current_kernel = &get_current_program().get_current_kernel(); + current_kernel = stmt->get_kernel(); } auto &args = current_kernel->args; TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size()); @@ -326,7 +326,7 @@ class TypeCheck : public IRVisitor { void visit(KernelReturnStmt *stmt) { Kernel *current_kernel = kernel; if (current_kernel == nullptr) { - current_kernel = &get_current_program().get_current_kernel(); + current_kernel = stmt->get_kernel(); } auto &rets = current_kernel->rets; TI_ASSERT(rets.size() >= 1); From 7185495376a730b900734e8f193cd0d2fb11cf65 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 01:19:13 +0800 Subject: [PATCH 02/15] format --- taichi/ir/ir.cpp | 4 ++-- taichi/ir/ir.h | 8 ++++---- taichi/program/kernel.cpp | 
2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 8ba37d10ce6cb..402af2f607998 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -298,7 +298,7 @@ IRNode *Stmt::get_ir_root() { return dynamic_cast(block); } -Kernel* Stmt::get_kernel() const { +Kernel *Stmt::get_kernel() const { if (parent) { return parent->get_kernel(); } else { @@ -715,7 +715,7 @@ Stmt *Block::mask() { } } -Kernel* Block::get_kernel() const { +Kernel *Block::get_kernel() const { Block *parent = this->parent; while (parent->parent) { parent = parent->parent; diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 043c8513a5bbf..11bd49167dde2 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -244,7 +244,7 @@ class IRNode { virtual void accept(IRVisitor *visitor) { TI_NOT_IMPLEMENTED } - virtual Kernel* get_kernel() const { + virtual Kernel *get_kernel() const { return nullptr; } virtual ~IRNode() = default; @@ -555,7 +555,7 @@ class Stmt : public IRNode { IRNode *get_ir_root(); - Kernel* get_kernel() const override; + Kernel *get_kernel() const override; virtual void repeat(int factor) { ret_type.width *= factor; @@ -813,7 +813,7 @@ class Block : public IRNode { std::vector> statements, trash_bin; Stmt *mask_var; std::vector stop_gradients; - Kernel* kernel; + Kernel *kernel; // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop // variables, and AllocaStmt for other variables. 
@@ -843,7 +843,7 @@ class Block : public IRNode { bool replace_usages = true); Stmt *lookup_var(const Identifier &ident) const; Stmt *mask(); - Kernel* get_kernel() const override; + Kernel *get_kernel() const override; Stmt *back() const { return statements.back().get(); diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 41f799453322a..9d4a51eadc61a 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -43,7 +43,7 @@ Kernel::Kernel(Program &program, program.start_function_definition(this); func(); program.end_function_definition(); - dynamic_cast(ir.get())->kernel = this; + dynamic_cast(ir.get())->kernel = this; } arch = program.config.arch; From e688aa4fd7d5853857c52c412360f9f18b45b976 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 02:46:01 +0800 Subject: [PATCH 03/15] fix segfault --- taichi/ir/ir.cpp | 3 +++ taichi/transforms/type_check.cpp | 9 ++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 402af2f607998..1057f4c528cc6 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -717,6 +717,9 @@ Stmt *Block::mask() { Kernel *Block::get_kernel() const { Block *parent = this->parent; + if (parent == nullptr) { + return kernel; + } while (parent->parent) { parent = parent->parent; } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index a5558261b53a9..e4f6d4ba07773 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -17,10 +17,9 @@ class TypeCheck : public IRVisitor { CompileConfig config; public: - TypeCheck(Kernel *kernel) : kernel(kernel) { - // TODO: remove dependency on get_current_program here - if (current_program != nullptr) - config = kernel->program.config; + TypeCheck(IRNode *root) { + kernel = root->get_kernel(); + config = kernel->program.config; allow_undefined_visitor = true; } @@ -418,7 +417,7 @@ namespace irpass { void typecheck(IRNode *root, Kernel *kernel) { 
analysis::check_fields_registered(root); - TypeCheck inst(kernel); + TypeCheck inst(root); root->accept(&inst); } From 34a85032bfe3d0da993e54544b938ab90e5677e6 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 12:25:38 +0800 Subject: [PATCH 04/15] refactor typecheck signature and misc --- taichi/analysis/clone.cpp | 2 +- taichi/ir/transforms.h | 2 +- taichi/transforms/type_check.cpp | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index 2415a04656f60..bc5af9dd55062 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -112,7 +112,7 @@ class IRCloner : public IRVisitor { static std::unique_ptr run(IRNode *root, Kernel *kernel) { if (kernel == nullptr) { - kernel = root->get_kernel()->program.current_kernel; + kernel = root->get_kernel(); } std::unique_ptr new_root = root->clone(); IRCloner cloner(new_root.get()); diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 80577cbd964b1..71a1cd1be1a95 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -23,7 +23,7 @@ void full_simplify(IRNode *root, Kernel *kernel = nullptr); void print(IRNode *root, std::string *output = nullptr); void lower(IRNode *root); -void typecheck(IRNode *root, Kernel *kernel = nullptr); +void typecheck(IRNode *root); void loop_vectorize(IRNode *root); void slp_vectorize(IRNode *root); void vector_split(IRNode *root, int max_width, bool serial_schedule); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index e4f6d4ba07773..5e92bd80870c2 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -19,7 +19,9 @@ class TypeCheck : public IRVisitor { public: TypeCheck(IRNode *root) { kernel = root->get_kernel(); - config = kernel->program.config; + if (kernel != nullptr) { + config = kernel->program.config; + } allow_undefined_visitor = true; } @@ -415,7 +417,7 @@ class TypeCheck : public IRVisitor { 
namespace irpass { -void typecheck(IRNode *root, Kernel *kernel) { +void typecheck(IRNode *root) { analysis::check_fields_registered(root); TypeCheck inst(root); root->accept(&inst); From c621207ada00c8b833d2794ab20ae4de115e7d5f Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 12:31:06 +0800 Subject: [PATCH 05/15] fix --- taichi/analysis/clone.cpp | 2 +- taichi/transforms/simplify.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index bc5af9dd55062..5a787bd19338e 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -120,7 +120,7 @@ class IRCloner : public IRVisitor { root->accept(&cloner); cloner.phase = IRCloner::replace_operand; root->accept(&cloner); - irpass::typecheck(new_root.get(), kernel); + irpass::typecheck(new_root.get()); irpass::fix_block_parents(new_root.get()); return new_root; } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index aaa62e305670e..4ada0a320642d 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -807,7 +807,7 @@ class BasicBlockSimplify : public IRVisitor { stmt->insert_before_me(std::move(sum)); stmt->parent->erase(stmt); // get types of adds and muls - irpass::typecheck(stmt->parent, kernel); + irpass::typecheck(stmt->parent); throw IRModified(); } From bddafd8d30731945300df3d6c15bb4a590a59b82 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 12:40:03 +0800 Subject: [PATCH 06/15] fix build --- taichi/transforms/lower_access.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index a4b4ca86ea71c..f0eb871001304 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -246,7 +246,7 @@ namespace irpass { void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel) { LowerAccess::run(root, lower_atomic); - 
typecheck(root, kernel); + typecheck(root); } } // namespace irpass From 00a5251b52f1508748a1eb6d04ac6ab61e0a436a Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 12:46:07 +0800 Subject: [PATCH 07/15] use utility function as --- taichi/program/kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 9d4a51eadc61a..2ca6a7de373f8 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -43,7 +43,7 @@ Kernel::Kernel(Program &program, program.start_function_definition(this); func(); program.end_function_definition(); - dynamic_cast(ir.get())->kernel = this; + ir->as()->kernel = this; } arch = program.config.arch; From c72a470697962114d9e7ac794f457023ece904c2 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 13:14:37 +0800 Subject: [PATCH 08/15] fix init --- taichi/ir/ir.h | 1 + 1 file changed, 1 insertion(+) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 11bd49167dde2..e0136bfacdaee 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -822,6 +822,7 @@ class Block : public IRNode { Block() { mask_var = nullptr; parent = nullptr; + kernel = nullptr; } bool has_container_statements(); From 113fc634ca91cedfcb28289c4d83fef73ad9fc07 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 13:52:01 +0800 Subject: [PATCH 09/15] fix segfault to pass CI --- taichi/transforms/constant_fold.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 00ec883a99d58..4e962f7e24aa8 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -113,7 +113,9 @@ class ConstantFold : public BasicStmtVisitor { rhs.dt, true}; auto *ker = get_jit_evaluator_kernel(id); - auto &current_program = stmt->get_kernel()->program; + // TODO: the kernel->program approach + // auto &current_program = stmt->get_kernel()->program; + auto& 
current_program = get_current_program(); auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel @@ -136,7 +138,9 @@ class ConstantFold : public BasicStmtVisitor { stmt->cast_type, false}; auto *ker = get_jit_evaluator_kernel(id); - auto &current_program = stmt->get_kernel()->program; + // TODO: the kernel->program approach + // auto &current_program = stmt->get_kernel()->program; + auto& current_program = get_current_program(); auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel @@ -206,7 +210,8 @@ void constant_fold(IRNode *root) { // disable constant_fold when config.debug is turned on. // Discussion: // https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010 - if (root->get_kernel()->program.config.debug) { + auto kernel = root->get_kernel(); + if (kernel && kernel->program.config.debug) { TI_TRACE("config.debug enabled, ignoring constant fold"); return; } From 3278a674905a65d4d38389b72b57fc97c6d63435 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 13:54:38 +0800 Subject: [PATCH 10/15] format again --- taichi/transforms/constant_fold.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 4e962f7e24aa8..e2322c36bea59 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -115,7 +115,7 @@ class ConstantFold : public BasicStmtVisitor { auto *ker = get_jit_evaluator_kernel(id); // TODO: the kernel->program approach // auto &current_program = stmt->get_kernel()->program; - auto& current_program = get_current_program(); + auto &current_program = get_current_program(); auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel @@ -140,7 +140,7 @@ class ConstantFold : public BasicStmtVisitor { auto *ker = 
get_jit_evaluator_kernel(id); // TODO: the kernel->program approach // auto &current_program = stmt->get_kernel()->program; - auto& current_program = get_current_program(); + auto &current_program = get_current_program(); auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel From 55161334ae3b426b730a1f648749ef1fc4c7a948 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 18:44:29 +0800 Subject: [PATCH 11/15] retrigger CI From 0f6c6316a4a1cb2f0b31763b8351ca2b7869925c Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 21:38:53 +0800 Subject: [PATCH 12/15] modify test with fake kernel --- taichi/transforms/constant_fold.cpp | 8 ++------ tests/cpp/test_alg_simp.cpp | 15 +++++++++++++++ tests/cpp/test_simplify.cpp | 4 ++++ 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index e2322c36bea59..3d8858b217048 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -113,9 +113,7 @@ class ConstantFold : public BasicStmtVisitor { rhs.dt, true}; auto *ker = get_jit_evaluator_kernel(id); - // TODO: the kernel->program approach - // auto &current_program = stmt->get_kernel()->program; - auto &current_program = get_current_program(); + auto &current_program = stmt->get_kernel()->program; auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel @@ -138,9 +136,7 @@ class ConstantFold : public BasicStmtVisitor { stmt->cast_type, false}; auto *ker = get_jit_evaluator_kernel(id); - // TODO: the kernel->program approach - // auto &current_program = stmt->get_kernel()->program; - auto &current_program = get_current_program(); + auto &current_program = stmt->get_kernel()->program; auto &ctx = current_program.get_context(); ContextArgSaveGuard _( ctx); // save input args, prevent override current kernel diff --git a/tests/cpp/test_alg_simp.cpp 
b/tests/cpp/test_alg_simp.cpp index df5f59d3d55cb..654c8444fa914 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -11,6 +11,10 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); + auto func = [](){}; + auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + block->kernel = kernel.get(); + auto global_load_addr = block->push_back(0, VectorType(1, DataType::i32)); auto global_load = block->push_back(global_load_addr); @@ -41,6 +45,10 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); + auto func = [](){}; + auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + block->kernel = kernel.get(); + auto global_load_addr = block->push_back(0, VectorType(1, DataType::f32)); auto global_load = block->push_back(global_load_addr); @@ -75,6 +83,9 @@ TI_TEST("alg_simp") { TI_TEST_PROGRAM; auto block = std::make_unique(); + auto func = [](){}; + auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + block->kernel = kernel.get(); auto global_load_addr = block->push_back(0, VectorType(1, DataType::i32)); @@ -101,6 +112,7 @@ TI_TEST("alg_simp") { TI_CHECK(block->size() == 3); // one address, one one, one store block = std::make_unique(); + block->kernel = kernel.get(); global_load_addr = block->push_back(8, VectorType(1, DataType::f32)); @@ -147,6 +159,9 @@ TI_TEST("alg_simp") { auto global_store = block->push_back(global_store_addr, and_result); + auto func = [](){}; + auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + block->kernel = kernel.get(); irpass::typecheck(block.get()); TI_CHECK(block->size() == 6); diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp index 66a00278e1eb1..aa2c1a9fe9d6b 100644 --- a/tests/cpp/test_simplify.cpp +++ b/tests/cpp/test_simplify.cpp @@ -12,6 +12,10 @@ TI_TEST("simplify") { auto block = std::make_unique(); + auto func = [](){}; + auto kernel = std::make_unique(get_current_program(), func, 
"fake_kernel"); + block->kernel = kernel.get(); + auto get_root = block->push_back(); auto linearized_empty = block->push_back( std::vector(), std::vector()); From ea8662840a205cae602d40dbc1201a518e732fd6 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 21:59:44 +0800 Subject: [PATCH 13/15] format --- tests/cpp/test_alg_simp.cpp | 20 ++++++++++++-------- tests/cpp/test_simplify.cpp | 5 +++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index 654c8444fa914..fab600fec778d 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -11,8 +11,9 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); - auto func = [](){}; - auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + auto func = []() {}; + auto kernel = + std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); auto global_load_addr = @@ -45,8 +46,9 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); - auto func = [](){}; - auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + auto func = []() {}; + auto kernel = + std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); auto global_load_addr = @@ -83,8 +85,9 @@ TI_TEST("alg_simp") { TI_TEST_PROGRAM; auto block = std::make_unique(); - auto func = [](){}; - auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + auto func = []() {}; + auto kernel = + std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); auto global_load_addr = @@ -159,8 +162,9 @@ TI_TEST("alg_simp") { auto global_store = block->push_back(global_store_addr, and_result); - auto func = [](){}; - auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + auto func = []() {}; + auto kernel = + std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); 
irpass::typecheck(block.get()); TI_CHECK(block->size() == 6); diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp index aa2c1a9fe9d6b..4170e18d31c3d 100644 --- a/tests/cpp/test_simplify.cpp +++ b/tests/cpp/test_simplify.cpp @@ -12,8 +12,9 @@ TI_TEST("simplify") { auto block = std::make_unique(); - auto func = [](){}; - auto kernel = std::make_unique(get_current_program(), func, "fake_kernel"); + auto func = []() {}; + auto kernel = + std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); auto get_root = block->push_back(); From 45282e15b911e51fa97ae6c048263f6468a4268a Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 22:16:03 +0800 Subject: [PATCH 14/15] remove CompileConfig from some transform passes --- taichi/ir/transforms.h | 6 ++--- taichi/program/async_engine.cpp | 4 +-- taichi/transforms/alg_simp.cpp | 3 ++- taichi/transforms/compile_to_offloads.cpp | 8 +++--- taichi/transforms/simplify.cpp | 4 +-- tests/cpp/test_alg_simp.cpp | 32 +++++++++++------------ tests/cpp/test_simplify.cpp | 2 +- 7 files changed, 29 insertions(+), 30 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 71a1cd1be1a95..0aa7feccd7740 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -14,13 +14,11 @@ void re_id(IRNode *root); void flag_access(IRNode *root); void die(IRNode *root); void simplify(IRNode *root, Kernel *kernel = nullptr); -void alg_simp(IRNode *root, const CompileConfig &config); +void alg_simp(IRNode *root); void whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); void extract_constant(IRNode *root); -void full_simplify(IRNode *root, - const CompileConfig &config, - Kernel *kernel = nullptr); +void full_simplify(IRNode *root, Kernel *kernel = nullptr); void print(IRNode *root, std::string *output = nullptr); void lower(IRNode *root); void typecheck(IRNode *root); diff --git a/taichi/program/async_engine.cpp 
b/taichi/program/async_engine.cpp index aadeedc60a51d..600fa71906ec4 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -51,7 +51,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) { flag_access(stmt); lower_access(stmt, true, kernel); flag_access(stmt); - full_simplify(stmt, kernel->program.config, kernel); + full_simplify(stmt, kernel); // analysis::verify(stmt); } auto func = CodeGenCPU(kernel, stmt).codegen(); @@ -266,7 +266,7 @@ bool AsyncEngine::fuse() { irpass::fix_block_parents(task_a); auto kernel = task_queue[i].kernel; - irpass::full_simplify(task_a, kernel->program.config, kernel); + irpass::full_simplify(task_a, kernel); task_queue[i].h = hash(task_a); modified = true; diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 41a53aab9c52c..78770c412f423 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -178,7 +178,8 @@ class AlgSimp : public BasicStmtVisitor { namespace irpass { -void alg_simp(IRNode *root, const CompileConfig &config) { +void alg_simp(IRNode *root) { + auto config = root->get_kernel()->program.config; return AlgSimp::run(root, config.fast_math); } diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 1b1c15c328749..c778d76d1fe3c 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -53,9 +53,9 @@ void compile_to_offloads(IRNode *ir, if (grad) { irpass::demote_atomics(ir); - irpass::full_simplify(ir, config); + irpass::full_simplify(ir); irpass::make_adjoint(ir, ad_use_stack); - irpass::full_simplify(ir, config); + irpass::full_simplify(ir); print("Adjoint"); irpass::analysis::verify(ir); } @@ -91,7 +91,7 @@ void compile_to_offloads(IRNode *ir, irpass::analysis::verify(ir); } - irpass::full_simplify(ir, config); + irpass::full_simplify(ir); print("Simplified II"); irpass::analysis::verify(ir); @@ -122,7 +122,7 @@ void 
compile_to_offloads(IRNode *ir, irpass::variable_optimization(ir, true); print("Store forwarded II"); - irpass::full_simplify(ir, config); + irpass::full_simplify(ir); print("Simplified III"); // Final field registration correctness & type checking diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 4ada0a320642d..22259270eaa7e 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1160,10 +1160,10 @@ void simplify(IRNode *root, Kernel *kernel) { } } -void full_simplify(IRNode *root, const CompileConfig &config, Kernel *kernel) { +void full_simplify(IRNode *root, Kernel *kernel) { constant_fold(root); if (advanced_optimization) { - alg_simp(root, config); + alg_simp(root); die(root); whole_kernel_cse(root); } diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index fab600fec778d..ba45d2a03e790 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -32,8 +32,8 @@ TI_TEST("alg_simp") { // irpass::print(block.get()); - irpass::alg_simp(block.get(), CompileConfig()); // should eliminate add - irpass::die(block.get()); // should eliminate zero + irpass::alg_simp(block.get()); // should eliminate add + irpass::die(block.get()); // should eliminate zero // irpass::print(block.get()); TI_CHECK(block->size() == 4); // two addresses, one load, one store @@ -71,9 +71,8 @@ TI_TEST("alg_simp") { // irpass::print(block.get()); - irpass::alg_simp(block.get(), - CompileConfig()); // should eliminate mul, div, sub - irpass::die(block.get()); // should eliminate zero, one + irpass::alg_simp(block.get()); // should eliminate mul, div, sub + irpass::die(block.get()); // should eliminate zero, one // irpass::print(block.get()); @@ -108,9 +107,10 @@ TI_TEST("alg_simp") { CompileConfig config_without_fast_math; config_without_fast_math.fast_math = false; - irpass::alg_simp(block.get(), - config_without_fast_math); // should eliminate mul, add - irpass::die(block.get()); // should 
eliminate zero, load + kernel->program.config = config_without_fast_math; + + irpass::alg_simp(block.get()); // should eliminate mul, add + irpass::die(block.get()); // should eliminate zero, load TI_CHECK(block->size() == 3); // one address, one one, one store @@ -132,16 +132,16 @@ TI_TEST("alg_simp") { TI_CHECK(block->size() == 10); irpass::constant_fold(block.get()); // should change 2 casts into const - irpass::alg_simp(block.get(), - config_without_fast_math); // should not eliminate - irpass::die(block.get()); // should eliminate 2 const + irpass::alg_simp(block.get()); // should not eliminate + irpass::die(block.get()); // should eliminate 2 const TI_CHECK(block->size() == 8); CompileConfig config_with_fast_math; config_with_fast_math.fast_math = true; - irpass::alg_simp(block.get(), - config_with_fast_math); // should eliminate mul, add - irpass::die(block.get()); // should eliminate zero, load + kernel->program.config = config_with_fast_math; + + irpass::alg_simp(block.get()); // should eliminate mul, add + irpass::die(block.get()); // should eliminate zero, load TI_CHECK(block->size() == 3); // one address, one one, one store } @@ -169,8 +169,8 @@ TI_TEST("alg_simp") { irpass::typecheck(block.get()); TI_CHECK(block->size() == 6); - irpass::alg_simp(block.get(), CompileConfig()); // should eliminate and - irpass::die(block.get()); // should eliminate zero + irpass::alg_simp(block.get()); // should eliminate and + irpass::die(block.get()); // should eliminate zero TI_CHECK(block->size() == 4); // two addresses, one load, one store TI_CHECK((*block)[0]->is()); diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp index 4170e18d31c3d..e4d4085f7e2bd 100644 --- a/tests/cpp/test_simplify.cpp +++ b/tests/cpp/test_simplify.cpp @@ -39,7 +39,7 @@ TI_TEST("simplify") { // TI_CHECK(block->size() == 11); // not required to check size here irpass::constant_fold(block.get()); - irpass::alg_simp(block.get(), CompileConfig()); + 
irpass::alg_simp(block.get()); irpass::die(block.get()); // should eliminate consts irpass::simplify(block.get()); if (advanced_optimization) { From b4f0e4a8ab20018088beb7d7ff50f1a2138c23e4 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 26 May 2020 22:53:00 +0800 Subject: [PATCH 15/15] Update taichi/transforms/alg_simp.cpp Co-authored-by: xumingkuan --- taichi/transforms/alg_simp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 178007e5a7b1d..51dc0c1137e0c 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -179,7 +179,7 @@ class AlgSimp : public BasicStmtVisitor { namespace irpass { bool alg_simp(IRNode *root) { - auto config = root->get_kernel()->program.config; + const auto &config = root->get_kernel()->program.config; return AlgSimp::run(root, config.fast_math); }