[refactor] [ir] IR system refactorings #1058

Merged 16 commits on May 26, 2020
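In short, the 16 commits make IR ownership and kernel lookup explicit: Kernel::ir becomes a std::unique_ptr<IRNode>, the root Block records its owning Kernel, and passes recover the kernel (and its CompileConfig) through the new IRNode::get_kernel() instead of the get_current_program() global.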
taichi/analysis/clone.cpp (2 additions, 2 deletions)
@@ -112,15 +112,15 @@ class IRCloner : public IRVisitor {

   static std::unique_ptr<IRNode> run(IRNode *root, Kernel *kernel) {
     if (kernel == nullptr) {
-      kernel = &get_current_program().get_current_kernel();
+      kernel = root->get_kernel();
     }
     std::unique_ptr<IRNode> new_root = root->clone();
     IRCloner cloner(new_root.get());
     cloner.phase = IRCloner::register_operand_map;
     root->accept(&cloner);
     cloner.phase = IRCloner::replace_operand;
     root->accept(&cloner);
-    irpass::typecheck(new_root.get(), kernel);
+    irpass::typecheck(new_root.get());
     irpass::fix_block_parents(new_root.get());
     return new_root;
   }
taichi/backends/metal/codegen_metal.cpp (1 addition, 1 deletion)
@@ -1031,7 +1031,7 @@ CodeGen::CodeGen(Kernel *kernel,
 FunctionType CodeGen::compile() {
   auto &config = kernel_->program.config;
   config.demote_dense_struct_fors = true;
-  irpass::compile_to_offloads(kernel_->ir, config,
+  irpass::compile_to_offloads(kernel_->ir.get(), config,
                               /*vectorize=*/false, kernel_->grad,
                               /*ad_use_stack=*/false, config.print_ir);

taichi/backends/opengl/codegen_opengl.cpp (1 addition, 1 deletion)
@@ -702,7 +702,7 @@ FunctionType OpenglCodeGen::gen(void) {
 }

 void OpenglCodeGen::lower() {
-  auto ir = kernel_->ir;
+  auto ir = kernel_->ir.get();
   auto &config = kernel_->program.config;
   config.demote_dense_struct_fors = true;
   irpass::compile_to_offloads(ir, config,
taichi/codegen/codegen.cpp (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ TLANG_NAMESPACE_BEGIN
 KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
     : prog(&kernel->program), kernel(kernel), ir(ir) {
   if (ir == nullptr)
-    this->ir = kernel->ir;
+    this->ir = kernel->ir.get();

   auto num_stmts = irpass::analysis::count_statements(this->ir);
   if (kernel->is_evaluator)
taichi/codegen/codegen_llvm.cpp (1 addition, 1 deletion)
@@ -272,7 +272,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
       ir(ir),
       prog(&kernel->program) {
   if (ir == nullptr)
-    this->ir = kernel->ir;
+    this->ir = kernel->ir.get();
   initialize_context();

   context_ty = get_runtime_type("Context");
taichi/ir/ir.cpp (19 additions, 0 deletions)
@@ -297,6 +297,14 @@ IRNode *Stmt::get_ir_root() {
   return dynamic_cast<IRNode *>(block);
 }

+Kernel *Stmt::get_kernel() const {
+  if (parent) {
+    return parent->get_kernel();
+  } else {
+    return nullptr;
+  }
+}
+
 std::vector<Stmt *> Stmt::get_operands() const {
   std::vector<Stmt *> ret;
   for (int i = 0; i < num_operands(); i++) {
@@ -706,6 +714,17 @@ Stmt *Block::mask() {
   }
 }

+Kernel *Block::get_kernel() const {
+  Block *parent = this->parent;
+  if (parent == nullptr) {
+    return kernel;
+  }
+  while (parent->parent) {
+    parent = parent->parent;
+  }
+  return parent->kernel;
+}
+
 void Block::set_statements(VecStatement &&stmts) {
   statements.clear();
   for (int i = 0; i < (int)stmts.size(); i++) {
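These two additions are the core lookup mechanism: a Stmt defers to its enclosing Block, and a Block climbs parent links to the root, which is the only node that stores the owning Kernel. A minimal self-contained sketch of the pattern, using simplified stand-in types rather than the real Taichi classes:

#include <cassert>

// Minimal stand-ins, not the real Taichi classes: a Stmt defers to its
// enclosing Block, and a Block climbs parent links to the root block,
// which is the only node that stores the owning Kernel.
struct Kernel;

struct Block {
  Block *parent = nullptr;
  Kernel *kernel = nullptr;  // non-null on the root block only

  Kernel *get_kernel() const {
    const Block *b = this;
    while (b->parent)  // climb to the root of the block tree
      b = b->parent;
    return b->kernel;
  }
};

struct Stmt {
  Block *parent = nullptr;

  Kernel *get_kernel() const {
    // A statement has no kernel pointer of its own; it asks its block.
    return parent ? parent->get_kernel() : nullptr;
  }
};

struct Kernel {
  Block root;
  Kernel() { root.kernel = this; }  // mirrors ir->as<Block>()->kernel = this;
};

int main() {
  Kernel k;                 // owns the root block
  Block inner;              // e.g. a loop body nested inside the kernel
  inner.parent = &k.root;
  Stmt s;
  s.parent = &inner;
  assert(s.get_kernel() == &k);    // any node can recover its kernel
  Stmt orphan;                     // IR not attached to any kernel
  assert(orphan.get_kernel() == nullptr);
  return 0;
}

Only the root block carries the back-pointer (set once, as kernel.cpp below does via ir->as<Block>()->kernel = this), so nested blocks stay free of bookkeeping at the cost of an O(depth) walk per query.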
taichi/ir/ir.h (8 additions, 0 deletions)
@@ -245,6 +245,9 @@ class IRNode {
   virtual void accept(IRVisitor *visitor) {
     TI_NOT_IMPLEMENTED
   }
+  virtual Kernel *get_kernel() const {
+    return nullptr;
+  }
   virtual ~IRNode() = default;

   template <typename T>
@@ -553,6 +556,8 @@ class Stmt : public IRNode {

   IRNode *get_ir_root();

+  Kernel *get_kernel() const override;
+
   virtual void repeat(int factor) {
     ret_type.width *= factor;
   }
@@ -809,6 +814,7 @@ class Block : public IRNode {
   std::vector<std::unique_ptr<Stmt>> statements, trash_bin;
   Stmt *mask_var;
   std::vector<SNode *> stop_gradients;
+  Kernel *kernel;

   // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
   // variables, and AllocaStmt for other variables.
@@ -817,6 +823,7 @@
   Block() {
     mask_var = nullptr;
     parent = nullptr;
+    kernel = nullptr;
   }

   bool has_container_statements();
@@ -838,6 +845,7 @@
                      bool replace_usages = true);
   Stmt *lookup_var(const Identifier &ident) const;
   Stmt *mask();
+  Kernel *get_kernel() const override;

   Stmt *back() const {
     return statements.back().get();
taichi/ir/transforms.h (4 additions, 5 deletions)
@@ -14,16 +14,15 @@ void re_id(IRNode *root);
 void flag_access(IRNode *root);
 void die(IRNode *root);
 void simplify(IRNode *root, Kernel *kernel = nullptr);
-bool alg_simp(IRNode *root, const CompileConfig &config);
+
+bool alg_simp(IRNode *root);
 void whole_kernel_cse(IRNode *root);
 void variable_optimization(IRNode *root, bool after_lower_access);
 void extract_constant(IRNode *root);
-void full_simplify(IRNode *root,
-                   const CompileConfig &config,
-                   Kernel *kernel = nullptr);
+void full_simplify(IRNode *root, Kernel *kernel = nullptr);
 void print(IRNode *root, std::string *output = nullptr);
 void lower(IRNode *root);
-void typecheck(IRNode *root, Kernel *kernel = nullptr);
+void typecheck(IRNode *root);
 void loop_vectorize(IRNode *root);
 void slp_vectorize(IRNode *root);
 void vector_split(IRNode *root, int max_width, bool serial_schedule);
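Net effect of the declaration changes above: passes that previously threaded a CompileConfig or Kernel through every call now take just the IR root and recover both via IRNode::get_kernel() (see the sketch after the ir.cpp hunks), which is why call sites below shrink from full_simplify(ir, config) to full_simplify(ir).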
taichi/program/async_engine.cpp (3 additions, 3 deletions)
@@ -51,7 +51,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
     flag_access(stmt);
     lower_access(stmt, true, kernel);
     flag_access(stmt);
-    full_simplify(stmt, kernel->program.config, kernel);
+    full_simplify(stmt, kernel);
     // analysis::verify(stmt);
   }
   auto func = CodeGenCPU(kernel, stmt).codegen();
@@ -108,7 +108,7 @@ ExecutionQueue::ExecutionQueue()
 void AsyncEngine::launch(Kernel *kernel) {
   if (!kernel->lowered)
     kernel->lower(false);
-  auto block = dynamic_cast<Block *>(kernel->ir);
+  auto block = dynamic_cast<Block *>(kernel->ir.get());
   TI_ASSERT(block);
   auto &offloads = block->statements;
   for (std::size_t i = 0; i < offloads.size(); i++) {
@@ -266,7 +266,7 @@ bool AsyncEngine::fuse() {
     irpass::fix_block_parents(task_a);

     auto kernel = task_queue[i].kernel;
-    irpass::full_simplify(task_a, kernel->program.config, kernel);
+    irpass::full_simplify(task_a, kernel);
     task_queue[i].h = hash(task_a);

     modified = true;
taichi/program/kernel.cpp (3 additions, 3 deletions)
@@ -36,14 +36,14 @@ Kernel::Kernel(Program &program,
   is_evaluator = false;
   compiled = nullptr;
   taichi::lang::context = std::make_unique<FrontendContext>();
-  ir_holder = taichi::lang::context->get_root();
-  ir = ir_holder.get();
+  ir = taichi::lang::context->get_root();

   {
     CurrentKernelGuard _(program, this);
     program.start_function_definition(this);
     func();
     program.end_function_definition();
+    ir->as<Block>()->kernel = this;
   }

   arch = program.config.arch;
@@ -74,7 +74,7 @@ void Kernel::lower(bool lower_access) {  // TODO: is a "Lowerer" class necessary?
     if (is_accessor && !config.print_accessor_ir)
       verbose = false;
     irpass::compile_to_offloads(
-        ir, config, /*vectorize*/ arch_is_cpu(arch), grad,
+        ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
         /*ad_use_stack*/ true, verbose, /*lower_global_access*/ lower_access);
   } else {
     TI_NOT_IMPLEMENTED
taichi/program/kernel.h (1 addition, 2 deletions)
@@ -10,8 +10,7 @@ class Program;

 class Kernel {
  public:
-  std::unique_ptr<IRNode> ir_holder;
-  IRNode *ir;
+  std::unique_ptr<IRNode> ir;
   Program &program;
   FunctionType compiled;
   std::string name;
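This header change is the heart of the ownership cleanup: the redundant ir_holder/ir pair collapses into a single owning pointer, which is why the earlier hunks switch borrowers to kernel->ir.get(). A compilable sketch of the resulting ownership model (stand-in types, not the real ones):

#include <memory>

// Sketch of the ownership change: Kernel used to carry both an owning
// ir_holder and a raw alias pointer ir; now one unique_ptr owns the IR
// tree and borrowers receive ir.get().
struct IRNode {
  virtual ~IRNode() = default;
};
struct Block : IRNode {};

struct Kernel {
  std::unique_ptr<IRNode> ir;  // sole owner of this kernel's IR
};

// Passes and code generators borrow; they never take ownership.
void compile_to_offloads(IRNode *ir) {
  (void)ir;  // ... lower, simplify, offload ...
}

int main() {
  Kernel k;
  k.ir = std::make_unique<Block>();
  compile_to_offloads(k.ir.get());  // matches the kernel_->ir.get() call sites
  return 0;
}  // the IR is destroyed exactly once, together with the kernel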
taichi/transforms/alg_simp.cpp (2 additions, 1 deletion)
@@ -178,7 +178,8 @@ class AlgSimp : public BasicStmtVisitor {

 namespace irpass {

-bool alg_simp(IRNode *root, const CompileConfig &config) {
+bool alg_simp(IRNode *root) {
+  const auto &config = root->get_kernel()->program.config;
   return AlgSimp::run(root, config.fast_math);
 }

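The new alg_simp body shows the lookup in action: the config comes from the IR's owning kernel rather than from a parameter. A self-contained sketch of the same pattern (simplified stand-in types; note that unlike the hunk above, which dereferences get_kernel() unconditionally, this sketch guards against IR with no kernel attached):

#include <iostream>

// Stand-ins for the pattern alg_simp now uses: pull the CompileConfig
// out of the IR's owning kernel instead of taking it as a parameter.
struct CompileConfig { bool fast_math = true; };
struct Program { CompileConfig config; };
struct Kernel { Program *program; };
struct IRNode {
  Kernel *kernel = nullptr;
  Kernel *get_kernel() const { return kernel; }
};

bool alg_simp(IRNode *root) {
  Kernel *kernel = root->get_kernel();
  if (kernel == nullptr)
    return false;  // orphan IR: no kernel, hence no config to consult
  const auto &config = kernel->program->config;
  return config.fast_math;  // the real pass calls AlgSimp::run(root, ...)
}

int main() {
  Program prog;
  Kernel k{&prog};
  IRNode root;
  root.kernel = &k;
  std::cout << std::boolalpha << alg_simp(&root) << '\n';  // prints: true
  return 0;
}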
taichi/transforms/compile_to_offloads.cpp (4 additions, 4 deletions)
@@ -53,9 +53,9 @@ void compile_to_offloads(IRNode *ir,

   if (grad) {
     irpass::demote_atomics(ir);
-    irpass::full_simplify(ir, config);
+    irpass::full_simplify(ir);
     irpass::make_adjoint(ir, ad_use_stack);
-    irpass::full_simplify(ir, config);
+    irpass::full_simplify(ir);
     print("Adjoint");
     irpass::analysis::verify(ir);
   }
@@ -91,7 +91,7 @@
     irpass::analysis::verify(ir);
   }

-  irpass::full_simplify(ir, config);
+  irpass::full_simplify(ir);
   print("Simplified II");
   irpass::analysis::verify(ir);

@@ -122,7 +122,7 @@
   irpass::variable_optimization(ir, true);
   print("Store forwarded II");

-  irpass::full_simplify(ir, config);
+  irpass::full_simplify(ir);
   print("Simplified III");

   // Final field registration correctness & type checking
taichi/transforms/constant_fold.cpp (8 additions, 5 deletions)
@@ -113,13 +113,14 @@ class ConstantFold : public BasicStmtVisitor {
                               rhs.dt,
                               true};
     auto *ker = get_jit_evaluator_kernel(id);
-    auto &ctx = get_current_program().get_context();
+    auto &current_program = stmt->get_kernel()->program;
+    auto &ctx = current_program.get_context();
[Review comment] @TH3CHARLie (Collaborator, Author), May 26, 2020: This segment has some issues with it. Please see comment below

     ContextArgSaveGuard _(
         ctx);  // save input args, prevent override current kernel
     ctx.set_arg<int64_t>(0, lhs.val_i64);
     ctx.set_arg<int64_t>(1, rhs.val_i64);
     (*ker)();
-    ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
+    ret.val_i64 = current_program.fetch_result<int64_t>(0);
     return true;
   }

@@ -135,12 +136,13 @@ class ConstantFold : public BasicStmtVisitor {
                               stmt->cast_type,
                               false};
     auto *ker = get_jit_evaluator_kernel(id);
-    auto &ctx = get_current_program().get_context();
+    auto &current_program = stmt->get_kernel()->program;
+    auto &ctx = current_program.get_context();
     ContextArgSaveGuard _(
         ctx);  // save input args, prevent override current kernel
     ctx.set_arg<int64_t>(0, operand.val_i64);
     (*ker)();
-    ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
+    ret.val_i64 = current_program.fetch_result<int64_t>(0);
     return true;
   }

@@ -204,7 +206,8 @@ void constant_fold(IRNode *root) {
   // disable constant_fold when config.debug is turned on.
   // Discussion:
   // https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010
-  if (get_current_program().config.debug) {
+  auto kernel = root->get_kernel();
+  if (kernel && kernel->program.config.debug) {
     TI_TRACE("config.debug enabled, ignoring constant fold");
     return;
   }
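Both folding paths above rely on ContextArgSaveGuard so that running the JIT evaluator kernel cannot clobber the arguments of the kernel being compiled. A sketch of that save/restore RAII idea under assumed simplified types (this is not Taichi's actual Context or guard implementation):

#include <array>
#include <cassert>
#include <cstdint>

// Snapshot the argument buffer on entry and restore it on exit, so a
// JIT evaluator run leaves the caller's kernel arguments untouched.
struct Context {
  std::array<int64_t, 8> args{};
  void set_arg(int i, int64_t v) { args[i] = v; }
};

class ContextArgSaveGuard {
  Context &ctx_;
  std::array<int64_t, 8> saved_;

 public:
  explicit ContextArgSaveGuard(Context &ctx) : ctx_(ctx), saved_(ctx.args) {}
  ~ContextArgSaveGuard() { ctx_.args = saved_; }  // restore on scope exit
};

int64_t jit_evaluate_add(Context &ctx, int64_t lhs, int64_t rhs) {
  ContextArgSaveGuard guard(ctx);  // protect the caller's arguments
  ctx.set_arg(0, lhs);
  ctx.set_arg(1, rhs);
  // ... (*ker)() would run here; fetch_result<int64_t>(0) reads the sum ...
  return ctx.args[0] + ctx.args[1];
}

int main() {
  Context ctx;
  ctx.set_arg(0, 42);  // pretend this is the current kernel's argument
  assert(jit_evaluate_add(ctx, 2, 3) == 5);
  assert(ctx.args[0] == 42);  // argument restored by the guard
  return 0;
}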
taichi/transforms/lower_access.cpp (1 addition, 1 deletion)
@@ -246,7 +246,7 @@ namespace irpass {

 void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel) {
   LowerAccess::run(root, lower_atomic);
-  typecheck(root, kernel);
+  typecheck(root);
 }

 }  // namespace irpass
taichi/transforms/simplify.cpp (3 additions, 3 deletions)
@@ -807,7 +807,7 @@ class BasicBlockSimplify : public IRVisitor {
         stmt->insert_before_me(std::move(sum));
         stmt->parent->erase(stmt);
         // get types of adds and muls
-        irpass::typecheck(stmt->parent, kernel);
+        irpass::typecheck(stmt->parent);
         throw IRModified();
       }

@@ -1160,10 +1160,10 @@ void simplify(IRNode *root, Kernel *kernel) {
   }
 }

-void full_simplify(IRNode *root, const CompileConfig &config, Kernel *kernel) {
+void full_simplify(IRNode *root, Kernel *kernel) {
   constant_fold(root);
   if (advanced_optimization) {
-    alg_simp(root, config);
+    alg_simp(root);
     die(root);
     whole_kernel_cse(root);
   }
taichi/transforms/type_check.cpp (9 additions, 8 deletions)
@@ -17,10 +17,11 @@ class TypeCheck : public IRVisitor {
   CompileConfig config;

  public:
-  TypeCheck(Kernel *kernel) : kernel(kernel) {
-    // TODO: remove dependency on get_current_program here
-    if (current_program != nullptr)
-      config = get_current_program().config;
+  TypeCheck(IRNode *root) {
+    kernel = root->get_kernel();
+    if (kernel != nullptr) {
+      config = kernel->program.config;
+    }
     allow_undefined_visitor = true;
   }

@@ -316,7 +317,7 @@
   void visit(ArgLoadStmt *stmt) {
     Kernel *current_kernel = kernel;
     if (current_kernel == nullptr) {
-      current_kernel = &get_current_program().get_current_kernel();
+      current_kernel = stmt->get_kernel();
     }
     auto &args = current_kernel->args;
     TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
@@ -326,7 +327,7 @@
   void visit(KernelReturnStmt *stmt) {
     Kernel *current_kernel = kernel;
     if (current_kernel == nullptr) {
-      current_kernel = &get_current_program().get_current_kernel();
+      current_kernel = stmt->get_kernel();
     }
     auto &rets = current_kernel->rets;
     TI_ASSERT(rets.size() >= 1);
@@ -416,9 +417,9 @@

 namespace irpass {

-void typecheck(IRNode *root, Kernel *kernel) {
+void typecheck(IRNode *root) {
   analysis::check_fields_registered(root);
-  TypeCheck inst(kernel);
+  TypeCheck inst(root);
   root->accept(&inst);
 }
