Skip to content

Commit

Permalink
[llvm] [lang] Add support for multiple return statements in real func…
Browse files Browse the repository at this point in the history
…tion (#4536)
  • Loading branch information
lin-hitonami authored Mar 18, 2022
1 parent f40725d commit c57b8d2
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 11 deletions.
42 changes: 35 additions & 7 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,20 @@ FunctionCreationGuard::FunctionCreationGuard(
}

// Finishes the function that was being built and puts the module builder
// back into the state recorded in old_func / ip / old_entry.
FunctionCreationGuard::~FunctionCreationGuard() {
  // Emit the closing `ret void` only when mb->returned is not set. An
  // unconditional CreateRetVoid() here would put a second terminator into a
  // block that already ends with a `ret`.
  // NOTE(review): presumably mb->returned is set by the ReturnStmt visitor
  // after it emits its own `ret` -- confirm against the builder class.
  if (!mb->returned) {
    mb->builder->CreateRetVoid();
  }
  mb->func = old_func;
  mb->builder->restoreIP(ip);
  // Clear the flag so it does not carry over into code generated afterwards.
  mb->returned = false;

  {
    // Temporarily move the insertion point to the allocas block to add the
    // branch into the entry block; the guard restores the insertion point
    // when this scope ends.
    llvm::IRBuilderBase::InsertPointGuard guard(*mb->builder);
    mb->builder->SetInsertPoint(allocas);
    mb->builder->CreateBr(entry);
    mb->entry_block = old_entry;
  }
  // verifyFunction() returns true when the function is broken, so assert the
  // negation; diagnostics are written to llvm::errs().
  TI_ASSERT(!llvm::verifyFunction(*body, &llvm::errs()));
}

namespace {
Expand Down Expand Up @@ -127,6 +131,9 @@ CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) {
// Visits the statements of the block in order and stops right after the first
// statement that leaves `returned` set, so nothing is generated after it.
void CodeGenLLVM::visit(Block *stmt_list) {
  auto &stmts = stmt_list->statements;
  for (auto it = stmts.begin(); it != stmts.end(); ++it) {
    (*it)->accept(this);
    if (returned) {
      return;
    }
  }
}

Expand Down Expand Up @@ -730,12 +737,20 @@ void CodeGenLLVM::visit(IfStmt *if_stmt) {
if (if_stmt->true_statements) {
if_stmt->true_statements->accept(this);
}
builder->CreateBr(after_if);
if (!returned) {
builder->CreateBr(after_if);
} else {
returned = false;
}
builder->SetInsertPoint(false_block);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
builder->CreateBr(after_if);
if (!returned) {
builder->CreateBr(after_if);
} else {
returned = false;
}
builder->SetInsertPoint(after_if);
}

Expand Down Expand Up @@ -906,7 +921,11 @@ void CodeGenLLVM::visit(WhileStmt *stmt) {

stmt->body->accept(this);

builder->CreateBr(body); // jump to head
if (!returned) {
builder->CreateBr(body); // jump to head
} else {
returned = false;
}

builder->SetInsertPoint(after_loop);
}
Expand Down Expand Up @@ -1001,8 +1020,11 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {

for_stmt->body->accept(this);
}

builder->CreateBr(loop_inc);
if (!returned) {
builder->CreateBr(loop_inc);
} else {
returned = false;
}
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
Expand Down Expand Up @@ -1098,6 +1120,8 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) {
tlctx->get_constant<int32>(idx++)});
}
}
builder->CreateRetVoid();
returned = true;
}

void CodeGenLLVM::visit(LocalLoadStmt *stmt) {
Expand Down Expand Up @@ -1653,7 +1677,11 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
}

void CodeGenLLVM::finalize_offloaded_task_function() {
builder->CreateRetVoid();
if (!returned) {
builder->CreateRetVoid();
} else {
returned = false;
}

// entry_block should jump to the body after all allocas are inserted
builder->SetInsertPoint(entry_block);
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
std::vector<OffloadedTask> offloaded_tasks;
llvm::BasicBlock *func_body_bb;
std::set<std::string> linked_modules;
bool returned{false};

std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

Expand Down
78 changes: 74 additions & 4 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,10 @@ def run() -> ti.i32:
def test_recursion():
@ti.experimental.real_func
def sum(f: ti.template(), l: ti.i32, r: ti.i32) -> ti.i32:
ret = 0
if l == r:
ret = f[l]
return f[l]
else:
ret = sum(f, l, (l + r) // 2) + sum(f, (l + r) // 2 + 1, r)
return ret
return sum(f, l, (l + r) // 2) + sum(f, (l + r) // 2 + 1, r)

f = ti.field(ti.i32, shape=100)
for i in range(100):
Expand All @@ -234,3 +232,75 @@ def get_sum() -> ti.i32:
return sum(f, 0, 99)

assert get_sum() == 99 * 50


@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_multiple_return():
    """A real function with an early return nested inside two `if`s.

    Starting from x == 0 the three calls behave as follows:
      foo(15):  x is not > 10, so x becomes 15 and 0 is returned.
      foo(10):  x is > 10 but not > 20, so x becomes 16, then 26; returns 0.
      foo(100): x is > 20, so 1 is returned before x is modified.
    The field therefore ends at 26.
    """
    x = ti.field(ti.i32, shape=())

    @ti.experimental.real_func
    def foo(val: ti.i32) -> ti.i32:
        if x[None] > 10:
            if x[None] > 20:
                # Early return: neither update below may run.
                return 1
            x[None] += 1
        x[None] += val
        return 0

    @ti.kernel
    def run():
        assert foo(15) == 0
        assert foo(10) == 0
        assert foo(100) == 1

    x[None] = 0
    run()
    assert x[None] == 26


@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_return_in_for():
    """A `return` inside a range-for body of a real function yields its value."""
    @ti.experimental.real_func
    def foo() -> ti.i32:
        for i in range(10):
            # Returns on the first iteration.
            return 42

    @ti.kernel
    def bar() -> ti.i32:
        return foo()

    assert bar() == 42


@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_return_in_while():
    """A `return` inside a while body of a real function yields its value."""
    @ti.experimental.real_func
    def foo() -> ti.i32:
        i = 1
        while i:
            # The loop condition is true on entry, so this returns at once.
            return 42

    @ti.kernel
    def bar() -> ti.i32:
        return foo()

    assert bar() == 42


@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_return_in_if_in_for():
    """A conditional `return` inside a loop, plus a fall-through `return`.

    foo(a) sums i = 0..a and returns once i reaches a + 1; when a + 1 is never
    reached within range(100) it returns the sum of 0..99 after the loop.
    """
    @ti.experimental.real_func
    def foo(a: ti.i32) -> ti.i32:
        s = 0
        for i in range(100):
            if i == a + 1:
                # Leave the loop and the function with the partial sum.
                return s
            s = s + i
        return s

    @ti.kernel
    def bar(a: ti.i32) -> ti.i32:
        return foo(a)

    # 0 + 1 + ... + 10 == 55
    assert bar(10) == 11 * 5
    # The condition never fires: 0 + 1 + ... + 99 == 4950
    assert bar(200) == 99 * 50

0 comments on commit c57b8d2

Please sign in to comment.