
[refactor] [ir] IR system refactorings #1058

Merged: 16 commits, May 26, 2020
2 changes: 1 addition & 1 deletion taichi/analysis/clone.cpp
@@ -112,7 +112,7 @@ class IRCloner : public IRVisitor {

static std::unique_ptr<IRNode> run(IRNode *root, Kernel *kernel) {
if (kernel == nullptr) {
- kernel = &get_current_program().get_current_kernel();
+ kernel = root->get_kernel()->program.current_kernel;
(TH3CHARLie marked this conversation as resolved.)
}
std::unique_ptr<IRNode> new_root = root->clone();
IRCloner cloner(new_root.get());
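This first hunk shows the pattern repeated throughout the PR: instead of reading the kernel out of global state via get_current_program().get_current_kernel(), passes ask the IR itself through the get_kernel() hook added in taichi/ir/ir.h further down. A minimal sketch of the resulting lookup; the fallback branch is only an illustrative assumption for IR that is not attached to any kernel, not part of the PR:

// Sketch only: resolve the owning Kernel from an IR root instead of global state.
Kernel *resolve_kernel(IRNode *root) {
  if (Kernel *k = root->get_kernel())   // nullptr when the root Block has no kernel set
    return k;
  // Hypothetical fallback mirroring the pre-refactor behavior:
  return &get_current_program().get_current_kernel();
}

Deriving the kernel from the tree lets cloning (and the passes below) operate on IR whose kernel is not the program's current one.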
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
@@ -1009,7 +1009,7 @@ CodeGen::CodeGen(Kernel *kernel,
FunctionType CodeGen::compile() {
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
- irpass::compile_to_offloads(kernel_->ir, config,
+ irpass::compile_to_offloads(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir);

2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
@@ -702,7 +702,7 @@ FunctionType OpenglCodeGen::gen(void) {
}

void OpenglCodeGen::lower() {
- auto ir = kernel_->ir;
+ auto ir = kernel_->ir.get();
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_offloads(ir, config,
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.cpp
@@ -15,7 +15,7 @@ TLANG_NAMESPACE_BEGIN
KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
: prog(&kernel->program), kernel(kernel), ir(ir) {
if (ir == nullptr)
- this->ir = kernel->ir;
+ this->ir = kernel->ir.get();

auto num_stmts = irpass::analysis::count_statements(this->ir);
if (kernel->is_evaluator)
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.cpp
@@ -272,7 +272,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
ir(ir),
prog(&kernel->program) {
if (ir == nullptr)
- this->ir = kernel->ir;
+ this->ir = kernel->ir.get();
initialize_context();

context_ty = get_runtime_type("Context");
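The four backends above (and async_engine.cpp below) change for the same reason: Kernel::ir becomes an owning std::unique_ptr<IRNode> later in the diff (taichi/program/kernel.h), so code that only reads or transforms the tree borrows it through .get(). A compressed sketch of that ownership split, using hypothetical stand-in names rather than the real Kernel class:

#include <memory>

struct IRNode {                 // minimal stand-in for taichi's IRNode
  virtual ~IRNode() = default;
};

void run_passes(IRNode *root);  // hypothetical pass entry point; borrows, never owns

struct KernelLike {             // hypothetical stand-in for taichi::lang::Kernel
  std::unique_ptr<IRNode> ir;   // single owner of the kernel's IR tree
};

void compile(KernelLike &k) {
  run_passes(k.ir.get());       // hand out a non-owning raw pointer
}

Collapsing the old ir_holder/ir pair into one member removes the risk of the raw pointer and its owner drifting apart.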
19 changes: 19 additions & 0 deletions taichi/ir/ir.cpp
@@ -298,6 +298,14 @@ IRNode *Stmt::get_ir_root() {
return dynamic_cast<IRNode *>(block);
}

+ Kernel *Stmt::get_kernel() const {
+ if (parent) {
+ return parent->get_kernel();
+ } else {
+ return nullptr;
+ }
+ }

std::vector<Stmt *> Stmt::get_operands() const {
std::vector<Stmt *> ret;
for (int i = 0; i < num_operands(); i++) {
@@ -707,6 +715,17 @@ Stmt *Block::mask() {
}
}

+ Kernel *Block::get_kernel() const {
+ Block *parent = this->parent;
+ if (parent == nullptr) {
+ return kernel;
+ }
+ while (parent->parent) {
+ parent = parent->parent;
+ }
+ return parent->kernel;
+ }

void Block::set_statements(VecStatement &&stmts) {
statements.clear();
for (int i = 0; i < (int)stmts.size(); i++) {
7 changes: 7 additions & 0 deletions taichi/ir/ir.h
@@ -244,6 +244,9 @@ class IRNode {
virtual void accept(IRVisitor *visitor) {
TI_NOT_IMPLEMENTED
}
+ virtual Kernel *get_kernel() const {
+ return nullptr;
+ }
virtual ~IRNode() = default;

template <typename T>
@@ -552,6 +555,8 @@ class Stmt : public IRNode {

IRNode *get_ir_root();

+ Kernel *get_kernel() const override;

virtual void repeat(int factor) {
ret_type.width *= factor;
}
@@ -808,6 +813,7 @@ class Block : public IRNode {
std::vector<std::unique_ptr<Stmt>> statements, trash_bin;
Stmt *mask_var;
std::vector<SNode *> stop_gradients;
+ Kernel *kernel;

// Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
// variables, and AllocaStmt for other variables.
@@ -837,6 +843,7 @@ bool replace_usages = true);
bool replace_usages = true);
Stmt *lookup_var(const Identifier &ident) const;
Stmt *mask();
+ Kernel *get_kernel() const override;

Stmt *back() const {
return statements.back().get();
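Taken together, the ir.cpp and ir.h hunks define the lookup chain: IRNode::get_kernel() defaults to nullptr, Stmt forwards to its parent Block, and Block climbs to the root Block, which now carries a non-owning Kernel *kernel back-pointer. A condensed, simplified sketch of that chain (types and members trimmed down; not the real declarations):

// Condensed sketch; the real classes live in taichi/ir/ir.h and ir.cpp.
struct Kernel;
struct Block;

struct IRNode {
  virtual Kernel *get_kernel() const { return nullptr; }  // default: not attached
  virtual ~IRNode() = default;
};

struct Stmt : IRNode {
  Block *parent = nullptr;
  Kernel *get_kernel() const override;   // defined after Block below
};

struct Block : IRNode {
  Block *parent = nullptr;    // enclosing block; null at the root
  Kernel *kernel = nullptr;   // set only on the root block (see kernel.cpp below)
  Kernel *get_kernel() const override {
    const Block *b = this;
    while (b->parent)         // climb to the root block
      b = b->parent;
    return b->kernel;
  }
};

Kernel *Stmt::get_kernel() const {
  return parent ? parent->get_kernel() : nullptr;
}

Because only the root Block stores the pointer, a lookup costs one walk up the block hierarchy and adds no bookkeeping to ordinary statements.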
2 changes: 1 addition & 1 deletion taichi/program/async_engine.cpp
@@ -108,7 +108,7 @@ ExecutionQueue::ExecutionQueue()
void AsyncEngine::launch(Kernel *kernel) {
if (!kernel->lowered)
kernel->lower(false);
- auto block = dynamic_cast<Block *>(kernel->ir);
+ auto block = dynamic_cast<Block *>(kernel->ir.get());
TI_ASSERT(block);
auto &offloads = block->statements;
for (std::size_t i = 0; i < offloads.size(); i++) {
6 changes: 3 additions & 3 deletions taichi/program/kernel.cpp
@@ -36,14 +36,14 @@ Kernel::Kernel(Program &program,
is_evaluator = false;
compiled = nullptr;
taichi::lang::context = std::make_unique<FrontendContext>();
- ir_holder = taichi::lang::context->get_root();
- ir = ir_holder.get();
+ ir = taichi::lang::context->get_root();

{
CurrentKernelGuard _(program, this);
program.start_function_definition(this);
func();
program.end_function_definition();
+ dynamic_cast<Block *>(ir.get())->kernel = this;
}

arch = program.config.arch;
@@ -74,7 +74,7 @@ void Kernel::lower(bool lower_access) { // TODO: is a "Lowerer" class necessary
if (is_accessor && !config.print_accessor_ir)
verbose = false;
irpass::compile_to_offloads(
- ir, config, /*vectorize*/ arch_is_cpu(arch), grad,
+ ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
/*ad_use_stack*/ true, verbose, /*lower_global_access*/ lower_access);
} else {
TI_NOT_IMPLEMENTED
3 changes: 1 addition & 2 deletions taichi/program/kernel.h
@@ -10,8 +10,7 @@ class Program;

class Kernel {
public:
- std::unique_ptr<IRNode> ir_holder;
- IRNode *ir;
+ std::unique_ptr<IRNode> ir;
Program &program;
FunctionType compiled;
std::string name;
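kernel.h collapses the ir_holder/ir pair into a single owning pointer, and the constructor change in kernel.cpp closes the loop from the other side: once the frontend has recorded the body, the root Block is pointed back at its Kernel, which is what makes get_kernel() resolve from anywhere inside the tree. A minimal sketch of the effect; this_kernel and some_stmt are placeholders, not identifiers from the PR:

// After construction, any statement inside this kernel's IR can recover its owner.
Block *root = dynamic_cast<Block *>(this_kernel->ir.get());
root->kernel = this_kernel;               // the wiring done in Kernel::Kernel above
Kernel *found = some_stmt->get_kernel();  // some_stmt: any Stmt nested under root
TI_ASSERT(found == this_kernel);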
12 changes: 7 additions & 5 deletions taichi/transforms/constant_fold.cpp
@@ -113,13 +113,14 @@ class ConstantFold : public BasicStmtVisitor {
rhs.dt,
true};
auto *ker = get_jit_evaluator_kernel(id);
- auto &ctx = get_current_program().get_context();
+ auto &current_program = stmt->get_kernel()->program;
+ auto &ctx = current_program.get_context();
(Review comment, @TH3CHARLie, Collaborator/Author, May 26, 2020: This segment has some issues with it. Please see comment below.)

ContextArgSaveGuard _(
ctx); // save input args, prevent override current kernel
ctx.set_arg<int64_t>(0, lhs.val_i64);
ctx.set_arg<int64_t>(1, rhs.val_i64);
(*ker)();
- ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
+ ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}

@@ -135,12 +136,13 @@ class ConstantFold : public BasicStmtVisitor {
stmt->cast_type,
false};
auto *ker = get_jit_evaluator_kernel(id);
- auto &ctx = get_current_program().get_context();
+ auto &current_program = stmt->get_kernel()->program;
+ auto &ctx = current_program.get_context();
ContextArgSaveGuard _(
ctx); // save input args, prevent override current kernel
ctx.set_arg<int64_t>(0, operand.val_i64);
(*ker)();
- ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
+ ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}

@@ -204,7 +206,7 @@ void constant_fold(IRNode *root) {
// disable constant_fold when config.debug is turned on.
// Discussion:
// https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010
- if (get_current_program().config.debug) {
+ if (root->get_kernel()->program.config.debug) {
TI_TRACE("config.debug enabled, ignoring constant fold");
return;
}
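With the kernel reachable from the statement being folded, the JIT-evaluator path no longer needs the global current program, either for the launch context or for reading back the result. A sketch of the call pattern, condensed into a hypothetical helper; the TypedConstant parameter types are an assumption, while the individual calls are the ones visible in the hunks above (assumes the relevant taichi headers are included):

// Hypothetical helper condensing the binary-evaluator launch shown above.
bool jit_evaluate_binary(Stmt *stmt, Kernel *ker,
                         const TypedConstant &lhs, const TypedConstant &rhs,
                         TypedConstant &ret) {
  Kernel *owner = stmt->get_kernel();
  if (owner == nullptr)                // defensive check; an assumption, not in the PR
    return false;
  auto &current_program = owner->program;
  auto &ctx = current_program.get_context();
  ContextArgSaveGuard _(ctx);          // save input args, avoid clobbering the current kernel
  ctx.set_arg<int64_t>(0, lhs.val_i64);
  ctx.set_arg<int64_t>(1, rhs.val_i64);
  (*ker)();
  ret.val_i64 = current_program.fetch_result<int64_t>(0);
  return true;
}

The reviewer's comment above indicates this segment still had open issues; independent of what those were, the sketch adds an explicit nullptr guard because get_kernel() returns nullptr for IR that is not attached to a kernel.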
13 changes: 6 additions & 7 deletions taichi/transforms/type_check.cpp
@@ -17,10 +17,9 @@ class TypeCheck : public IRVisitor {
CompileConfig config;

public:
- TypeCheck(Kernel *kernel) : kernel(kernel) {
- // TODO: remove dependency on get_current_program here
- if (current_program != nullptr)
- config = get_current_program().config;
+ TypeCheck(IRNode *root) {
+ kernel = root->get_kernel();
+ config = kernel->program.config;
allow_undefined_visitor = true;
}

@@ -316,7 +315,7 @@
void visit(ArgLoadStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
- current_kernel = &get_current_program().get_current_kernel();
+ current_kernel = stmt->get_kernel();
}
auto &args = current_kernel->args;
TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
@@ -326,7 +325,7 @@
void visit(KernelReturnStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
- current_kernel = &get_current_program().get_current_kernel();
+ current_kernel = stmt->get_kernel();
}
auto &rets = current_kernel->rets;
TI_ASSERT(rets.size() >= 1);
@@ -418,7 +417,7 @@ namespace irpass {

void typecheck(IRNode *root, Kernel *kernel) {
(TH3CHARLie marked this conversation as resolved.)
analysis::check_fields_registered(root);
- TypeCheck inst(kernel);
+ TypeCheck inst(root);
root->accept(&inst);
}

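TypeCheck now takes the IR root and derives both the kernel and the compile config from it, and the per-statement fallbacks in visit(ArgLoadStmt *) and visit(KernelReturnStmt *) likewise ask the statement rather than the global program, which removes the TODO the old constructor carried. A minimal usage sketch; kernel here stands for any existing taichi::lang::Kernel, and irpass::typecheck keeps its Kernel * parameter for existing callers:

// Sketch: running the pass over a kernel's IR after this refactor.
IRNode *root = kernel->ir.get();   // borrowed from the owning Kernel
irpass::typecheck(root, kernel);   // internally constructs TypeCheck(root)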