Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][refactor] Convert loop_var into LoopIndexStmt #953

Merged
merged 22 commits into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/hello.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ Let's dive into this simple Taichi program.

import taichi as ti
-------------------
Taichi is a domain-specific language (DSL) embedded in Python. To make Taichi as easy to use as a Python package,
we have done heavy engineering with this goal in mind - letting every Python programmer write Taichi codes with
minimal learning effort. You can even use your favorite Python package management system, Python IDEs and other
Taichi is a domain-specific language (DSL) embedded in Python. To make Taichi as easy to use as a Python package,
we have done heavy engineering with this goal in mind - letting every Python programmer write Taichi codes with
minimal learning effort. You can even use your favorite Python package management system, Python IDEs and other
Python packages in conjunction with Taichi.

Portability
Expand Down
46 changes: 32 additions & 14 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,46 @@ class IRVerifier : public BasicStmtVisitor {
TI_ASSERT(stmt->ptr->is<AllocaStmt>());
}

void visit(LoopIndexStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->loop);
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for ||
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
} else {
TI_ASSERT(stmt->loop->is<StructForStmt>() ||
stmt->loop->is<RangeForStmt>());
}
}

void visit(RangeForStmt *for_stmt) override {
basic_verify(for_stmt);
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
})
.empty(),
"loop_var of {} modified", for_stmt->id);
if (for_stmt->loop_var) {
TI_ASSERT(for_stmt->loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
for_stmt->loop_var->parent,
[&](Stmt *s) {
if (auto store = s->cast<LocalStoreStmt>())
return store->ptr == for_stmt->loop_var;
else if (auto atomic = s->cast<AtomicOpStmt>()) {
return atomic->dest == for_stmt->loop_var;
} else {
return false;
}
})
.empty(),
"loop_var of {} modified", for_stmt->id);
}
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
basic_verify(for_stmt);
for (auto &loop_var : for_stmt->loop_vars) {
if (!loop_var)
continue;
TI_ASSERT(loop_var->is<AllocaStmt>());
TI_ASSERT_INFO(irpass::analysis::gather_statements(
loop_var->parent,
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(DataType::i32);
offloaded_loop_vars_llvm[stmt].push_back(loop_var);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(1), loop_var);
stmt->body->accept(this);

Expand Down
54 changes: 25 additions & 29 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,22 @@ class KernelCodegen : public IRVisitor {
}

void visit(LoopIndexStmt *stmt) override {
using TaskType = OffloadedStmt::TaskType;
const auto type = current_kernel_attribs_->task_type;
const auto stmt_name = stmt->raw_name();
if (type == TaskType::range_for) {
if (stmt->loop->is<OffloadedStmt>()) {
using TaskType = OffloadedStmt::TaskType;
const auto type = stmt->loop->as<OffloadedStmt>()->task_type;
if (type == TaskType::range_for) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
} else {
TI_NOT_IMPLEMENTED;
}
} else if (stmt->loop->is<RangeForStmt>()) {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
stmt->index);
emit("const int {} = {};", stmt_name, stmt->loop->raw_name());
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -445,29 +452,18 @@ class KernelCodegen : public IRVisitor {

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), 1);
emit(" int {} = {}_;", loop_var->raw_name(), loop_var->raw_name());
}
auto loop_var_name = for_stmt->raw_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{", loop_var_name,
for_stmt->begin->raw_name(), loop_var_name,
for_stmt->end->raw_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = metal_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + ({})1) {{", type_name,
loop_var->raw_name(), for_stmt->begin->raw_name(),
loop_var->raw_name(), for_stmt->end->raw_name(),
loop_var->raw_name(), loop_var->raw_name(), type_name);
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->raw_name(), loop_var_name,
for_stmt->begin->raw_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
Expand Down
5 changes: 3 additions & 2 deletions taichi/backends/metal/shaders/helpers.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ STR(

inline int ifloordiv(int lhs, int rhs) {
const int intm = (lhs / rhs);
return (((lhs < 0) != (rhs < 0) && lhs &&
(rhs * intm != lhs)) ? (intm - 1) : intm);
return (((lhs < 0) != (rhs < 0) && lhs && (rhs * intm != lhs))
? (intm - 1)
: intm);
}

int32_t pow_i32(int32_t x, int32_t n) {
Expand Down
71 changes: 33 additions & 38 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,14 +414,13 @@ class KernelGen : public IRVisitor {
emit("{} {} = atan({}, {});", dt_name, bin_name, lhs_name, rhs_name);
}
return;
} else if (bin->op_type == BinaryOpType::pow
&& is_integral(bin->rhs->element_type())) {
// The GLSL `pow` is not so percise for `int`... e.g.: `pow(5, 3)` obtains 124
// So that we have to use some hack to make it percise.
// Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902
} else if (bin->op_type == BinaryOpType::pow &&
is_integral(bin->rhs->element_type())) {
// The GLSL `pow` is not so percise for `int`... e.g.: `pow(5, 3)` obtains
// 124 So that we have to use some hack to make it percise. Discussion:
// https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902
emit("{} {} = {}(fast_pow_{}({}, {}));", dt_name, bin_name, dt_name,
data_type_short_name(bin->lhs->element_type()),
lhs_name, rhs_name);
data_type_short_name(bin->lhs->element_type()), lhs_name, rhs_name);
used.fast_pow = true;
return;
}
Expand Down Expand Up @@ -602,38 +601,32 @@ class KernelGen : public IRVisitor {
}

void visit(LoopIndexStmt *stmt) override {
TI_ASSERT(!stmt->is_struct_for);
TI_ASSERT(stmt->index == 0); // TODO: multiple indices
emit("int {} = _itv;", stmt->short_name());
if (stmt->loop->is<OffloadedStmt>()) {
TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::range_for);
emit("int {} = _itv;", stmt->short_name());
} else if (stmt->loop->is<RangeForStmt>()) {
emit("int {} = {};", stmt->short_name(), stmt->loop->short_name());
} else {
TI_NOT_IMPLEMENTED
}
}

void visit(RangeForStmt *for_stmt) override {
TI_ASSERT(for_stmt->width() == 1);
auto *loop_var = for_stmt->loop_var;
if (loop_var->ret_type.data_type == DataType::i32) {
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{",
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
// variable named `loop_var->short_name()` is already allocated by
// alloca
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
} else {
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), loop_var->short_name(), 1);
emit(" {} = {}_;", loop_var->short_name(), loop_var->short_name());
}
auto loop_var_name = for_stmt->short_name();
if (!for_stmt->reversed) {
emit("for (int {}_ = {}; {}_ < {}; {}_ = {}_ + {}) {{", loop_var_name,
for_stmt->begin->short_name(), loop_var_name,
for_stmt->end->short_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
} else {
TI_ASSERT(!for_stmt->reversed);
const auto type_name = opengl_data_type_name(loop_var->element_type());
emit("for ({} {} = {}; {} < {}; {} = {} + 1) {{", type_name,
loop_var->short_name(), for_stmt->begin->short_name(),
loop_var->short_name(), for_stmt->end->short_name(),
loop_var->short_name(), loop_var->short_name());
// reversed for loop
emit("for (int {}_ = {} - 1; {}_ >= {}; {}_ = {}_ - {}) {{",
loop_var_name, for_stmt->end->short_name(), loop_var_name,
for_stmt->begin->short_name(), loop_var_name, loop_var_name, 1);
emit(" int {} = {}_;", loop_var_name, loop_var_name);
}
for_stmt->body->accept(this);
emit("}}");
Expand Down Expand Up @@ -730,12 +723,14 @@ void OpenglCodeGen::lower() {
auto ir = kernel_->ir;
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
auto res = irpass::compile_to_offloads(ir, config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir,
/*lower_global_access*/true);
auto res =
irpass::compile_to_offloads(ir, config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir,
/*lower_global_access*/ true);
global_tmps_buffer_size_ = res.total_size;
TI_TRACE("[glsl] Global temporary buffer size {} B", global_tmps_buffer_size_);
TI_TRACE("[glsl] Global temporary buffer size {} B",
global_tmps_buffer_size_);
#ifdef _GLSL_DEBUG
irpass::print(ir);
#endif
Expand Down
35 changes: 19 additions & 16 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,13 +790,16 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);

auto loop_var = create_entry_block_alloca(DataType::i32);
loop_vars_llvm[for_stmt].push_back(loop_var);

if (!for_stmt->reversed) {
builder->CreateStore(llvm_val[for_stmt->begin],
llvm_val[for_stmt->loop_var]);
builder->CreateStore(llvm_val[for_stmt->begin], loop_var);
} else {
builder->CreateStore(
builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
llvm_val[for_stmt->loop_var]);
loop_var);
}
builder->CreateBr(loop_test);

Expand All @@ -805,15 +808,13 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->SetInsertPoint(loop_test);
llvm::Value *cond;
if (!for_stmt->reversed) {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
llvm_val[for_stmt->end]);
cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(loop_var),
llvm_val[for_stmt->end]);
} else {
cond =
builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
builder->CreateLoad(llvm_val[for_stmt->loop_var]),
llvm_val[for_stmt->begin]);
cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SGE,
builder->CreateLoad(loop_var),
llvm_val[for_stmt->begin]);
}
builder->CreateCondBr(cond, body, after_loop);
}
Expand All @@ -833,9 +834,9 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1));
create_increment(loop_var, tlctx->get_constant(1));
} else {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1));
create_increment(loop_var, tlctx->get_constant(-1));
}
builder->CreateBr(loop_test);
}
Expand Down Expand Up @@ -1408,13 +1409,15 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

void CodeGenLLVM::visit(LoopIndexStmt *stmt) {
TI_ASSERT(&module->getContext() == tlctx->get_this_thread_context());
if (stmt->is_struct_for) {
if (stmt->loop->is<OffloadedStmt>() &&
stmt->loop->as<OffloadedStmt>()->task_type ==
OffloadedStmt::TaskType::struct_for) {
llvm_val[stmt] = builder->CreateLoad(builder->CreateGEP(
current_coordinates, {tlctx->get_constant(0), tlctx->get_constant(0),
tlctx->get_constant(stmt->index)}));
} else {
llvm_val[stmt] = builder->CreateLoad(
offloaded_loop_vars_llvm[current_offloaded_stmt][stmt->index]);
llvm_val[stmt] =
builder->CreateLoad(loop_vars_llvm[stmt->loop][stmt->index]);
}
}

Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;

std::unordered_map<OffloadedStmt *, std::vector<llvm::Value *>>
offloaded_loop_vars_llvm;
std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

using IRVisitor::visit;
using LLVMModuleBuilder::call;
Expand Down
5 changes: 3 additions & 2 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,11 @@ void Block::replace_statements_in_range(int start,
}

void Block::replace_with(Stmt *old_statement,
std::unique_ptr<Stmt> &&new_statement) {
std::unique_ptr<Stmt> &&new_statement,
bool replace_usages) {
VecStatement vec;
vec.push_back(std::move(new_statement));
replace_with(old_statement, std::move(vec));
replace_with(old_statement, std::move(vec), replace_usages);
}

Stmt *Block::lookup_var(const Identifier &ident) const {
Expand Down
4 changes: 3 additions & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,9 @@ class Block : public IRNode {
void insert(VecStatement &&stmt, int location = -1);
void replace_statements_in_range(int start, int end, VecStatement &&stmts);
void set_statements(VecStatement &&stmts);
void replace_with(Stmt *old_statement, std::unique_ptr<Stmt> &&new_statement);
void replace_with(Stmt *old_statement,
std::unique_ptr<Stmt> &&new_statement,
bool replace_usages = true);
void insert_before(Stmt *old_statement, VecStatement &&new_statements);
void replace_with(Stmt *old_statement,
VecStatement &&new_statements,
Expand Down
Loading