[Lang] [IR] Kernel scalar return support (ArgStoreStmt -> KernelReturnStmt) #917

Merged: 25 commits, May 8, 2020
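Before the per-file diffs, a hedged usage sketch of the feature this PR adds: a kernel annotated with a scalar return type can return a value directly to Python instead of writing it into a global field. The ti.init arguments and concrete numbers below are assumptions for illustration, not part of the diff.

import taichi as ti

ti.init(arch=ti.cpu)  # backend choice is an assumption; any supported arch should work

@ti.kernel
def add(a: ti.f32, b: ti.f32) -> ti.f32:
    # With this PR, an annotated kernel may `return` a scalar directly.
    return a + b

print(add(1.0, 2.0))  # expected to print 3.0
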
18 changes: 17 additions & 1 deletion python/taichi/lang/kernel.py
@@ -158,6 +158,7 @@ def __init__(self, func, is_grad, classkernel=False):
self.is_grad = is_grad
self.arguments = []
self.argument_names = []
self.return_type = None
self.classkernel = classkernel
self.extract_arguments()
self.template_slot_locations = []
@@ -180,6 +181,8 @@ def reset(self):

def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
@@ -249,6 +252,9 @@ def materialize(self, key=None, args=None, arg_features=None):
if isinstance(anno, ast.Name):
global_vars[anno.id] = self.arguments[i]

if isinstance(func_body.returns, ast.Name):
global_vars[func_body.returns.id] = self.return_type

if self.is_grad:
from .ast_checker import KernelSimplicityASTChecker
KernelSimplicityASTChecker(self.func).visit(tree)
@@ -388,12 +394,22 @@ def call_back():

t_kernel()

ret = None
ret_dt = self.return_type
if ret_dt is not None:
if taichi_lang_core.is_integral(ret_dt):
ret = t_kernel.get_ret_int(0)
else:
ret = t_kernel.get_ret_float(0)

if callbacks:
import taichi as ti
ti.sync()
for c in callbacks:
c()

return ret

return func__

def match_ext_arr(self, v, needed):
@@ -481,7 +497,7 @@ def wrapped(*args, **kwargs):

@wraps(func)
def wrapped(*args, **kwargs):
primal(*args, **kwargs)
return primal(*args, **kwargs)

wrapped.grad = adjoint

5 changes: 5 additions & 0 deletions python/taichi/lang/kernel_arguments.py
@@ -40,3 +40,8 @@ def decl_scalar_arg(dt):
def decl_ext_arr_arg(dt, dim):
id = taichi_lang_core.decl_arg(dt, True)
return Expr(taichi_lang_core.make_external_tensor_expr(dt, dim, id))


def decl_scalar_ret(dt):
id = taichi_lang_core.decl_ret(dt)
return id
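For context, the AST transformer in the next file injects a statement equivalent to the sketch below at the top of a kernel body when the kernel carries a return annotation; the ti.f32 type here is an assumed annotation, not something this helper hard-codes.

# Hedged sketch of the declaration injected for a kernel annotated `-> ti.f32`;
# under the hood it reserves return slot 0 via taichi_lang_core.decl_ret.
ti.decl_scalar_ret(ti.f32)
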
28 changes: 25 additions & 3 deletions python/taichi/lang/transformer.py
@@ -37,6 +37,7 @@ def __init__(self,
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features
self.returns = None

def variable_scope(self, *args):
return ScopeGuard(self, *args)
@@ -594,6 +595,15 @@ def visit_FunctionDef(self, node):
"Function definition not allowed in 'ti.kernel'.")
# Transform as kernel
arg_decls = []

# Treat return type
if node.returns is not None:
ret_init = self.parse_stmt('ti.decl_scalar_ret(0)')
ret_init.value.args[0] = node.returns
self.returns = node.returns
arg_decls.append(ret_init)
node.returns = None

for i, arg in enumerate(args.args):
if isinstance(self.func.arguments[i], ti.template):
continue
@@ -620,6 +630,7 @@ def visit_FunctionDef(self, node):
arg_decls.append(arg_init)
# remove original args
node.args.args = []

else: # ti.func
for decorator in node.decorator_list:
if (isinstance(decorator, ast.Attribute)
@@ -640,8 +651,10 @@ def visit_FunctionDef(self, node):
'_by_value__')
args.args[i].arg += '_by_value__'
arg_decls.append(arg_init)

with self.variable_scope():
self.generic_visit(node)

node.body = arg_decls + node.body
return node

@@ -736,7 +749,16 @@ def visit_Assert(self, node):
def visit_Return(self, node):
self.generic_visit(node)
if self.is_kernel:
raise TaichiSyntaxError(
'"return" not allowed in \'ti.kernel\'. Please walk around by storing the return result to a global variable.'
)
# TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
if node.value is not None:
if self.returns is None:
raise TaichiSyntaxError('kernel with return value must be '
'annotated with a return type, e.g. def func() -> ti.f32')
ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
ret_stmt = self.parse_stmt(
'ti.core.create_kernel_return(ret.ptr)')
ret_stmt.value.args[0].value = ret_expr
return ret_stmt
return node
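Combining the two templates above, a hedged sketch of what `return a + b` inside a kernel annotated `-> ti.f32` is rewritten into (illustrative source form only; the actual transformation splices AST nodes rather than text):

# The value is wrapped in ti.Expr, cast to the annotated return type, and
# handed to the frontend as a kernel-return statement.
ti.core.create_kernel_return(ti.cast(ti.Expr(a + b), ti.f32).ptr)
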
9 changes: 9 additions & 0 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -498,6 +498,15 @@ class KernelGen : public IRVisitor {
const_stmt->short_name(), const_stmt->val[0].stringify());
}

void visit(KernelReturnStmt *stmt) override {
used.argument = true;
used.int64 = true;
// TODO: consider use _rets_{}_ instead of _args_{}_
// TODO: use stmt->ret_id instead of 0 as index
emit("_args_{}_[0] = {};", data_type_short_name(stmt->element_type()),
stmt->value->short_name());
}

void visit(ArgLoadStmt *stmt) override {
const auto dt = opengl_data_type_name(stmt->element_type());
used.argument = true;
5 changes: 3 additions & 2 deletions taichi/backends/opengl/opengl_api.cpp
@@ -366,12 +366,13 @@ struct CompiledKernel {

struct CompiledProgram::Impl {
std::vector<std::unique_ptr<CompiledKernel>> kernels;
int arg_count;
int arg_count, ret_count;
std::map<int, size_t> ext_arr_map;
size_t gtmp_size;

Impl(Kernel *kernel, size_t gtmp_size) : gtmp_size(gtmp_size) {
arg_count = kernel->args.size();
ret_count = kernel->rets.size();
for (int i = 0; i < arg_count; i++) {
if (kernel->args[i].is_nparray) {
ext_arr_map[i] = kernel->args[i].size;
@@ -390,7 +391,7 @@

void launch(Context &ctx, GLSLLauncher *launcher) const {
std::vector<IOV> iov;
iov.push_back(IOV{ctx.args, arg_count * sizeof(uint64_t)});
iov.push_back(IOV{ctx.args, std::max(arg_count, ret_count) * sizeof(uint64_t)});
auto gtmp_arr = std::vector<char>(gtmp_size);
void *gtmp_base = gtmp_arr.data(); // std::calloc(gtmp_size, 1);
iov.push_back(IOV{gtmp_base, gtmp_size});
18 changes: 18 additions & 0 deletions taichi/codegen/codegen_llvm.cpp
@@ -881,6 +881,24 @@ void CodeGenLLVM::visit(ArgStoreStmt *stmt) {
}
}

void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->is_ptr) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type.data_type)
->getPrimitiveSizeInBits();
llvm::Type *intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
auto extended = builder->CreateZExt(
builder->CreateBitCast(llvm_val[stmt->value], intermediate_type),
dest_ty);
builder->CreateCall(get_runtime_function("LLVMRuntime_store_result"),
{get_runtime(), extended});
}
}

void CodeGenLLVM::visit(LocalLoadStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->ptr[0].var]);
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
@@ -171,6 +171,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ArgStoreStmt *stmt) override;

void visit(KernelReturnStmt *stmt) override;

void visit(LocalLoadStmt *stmt) override;

void visit(LocalStoreStmt *stmt) override;
2 changes: 2 additions & 0 deletions taichi/inc/statements.inc.h
@@ -12,6 +12,7 @@ PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendArgStoreStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FrontendKernelReturnStmt)

// Middle-end statement

@@ -24,6 +25,7 @@ PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncBodyStmt)
PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(KernelReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ExternalPtrStmt)
14 changes: 14 additions & 0 deletions taichi/ir/frontend_ir.h
@@ -198,6 +198,20 @@ class FrontendWhileStmt : public Stmt {
DEFINE_ACCEPT
};

class FrontendKernelReturnStmt : public Stmt {
public:
Expr value;

FrontendKernelReturnStmt(const Expr &value) : value(value) {
}

bool is_container_statement() const override {
return false;
}

DEFINE_ACCEPT
};

// Expressions

class ArgLoadExpression : public Expression {
16 changes: 16 additions & 0 deletions taichi/ir/ir.h
@@ -1453,6 +1453,22 @@ class FuncCallStmt : public Stmt {
DEFINE_ACCEPT
};

class KernelReturnStmt : public Stmt {
public:
Stmt *value;
(Review comment from the author on `Stmt *value;`: "Let's put multi-return in another PR.")


KernelReturnStmt(Stmt *value) : value(value) {
TI_STMT_REG_FIELDS;
}

bool is_container_statement() const override {
return false;
}

TI_STMT_DEF_FIELDS(value);
DEFINE_ACCEPT
};

class WhileStmt : public Stmt {
public:
Stmt *mask;
51 changes: 4 additions & 47 deletions taichi/ir/snode.cpp
@@ -122,25 +122,6 @@ void SNode::write_float(const std::vector<int> &I, float64 val) {
(*writer_kernel)();
}

uint64 SNode::fetch_reader_result() {
uint64 ret;
auto arch = get_current_program().config.arch;
if (arch == Arch::cuda) {
// TODO: refactor
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().memcpy_device_to_host(
&ret, get_current_program().result_buffer, sizeof(uint64));
#else
TI_NOT_IMPLEMENTED;
#endif
} else if (arch_is_cpu(arch)) {
ret = *(uint64 *)get_current_program().result_buffer;
} else {
ret = get_current_program().context.get_arg_as_uint64(num_active_indices);
}
return ret;
}

float64 SNode::read_float(const std::vector<int> &I) {
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
@@ -149,14 +130,8 @@ float64 SNode::read_float(const std::vector<int> &I) {
get_current_program().synchronize();
(*reader_kernel)();
get_current_program().synchronize();
auto ret = fetch_reader_result();
if (dt == DataType::f32) {
return taichi_union_cast_with_different_sizes<float32>(ret);
} else if (dt == DataType::f64) {
return taichi_union_cast_with_different_sizes<float64>(ret);
} else {
TI_NOT_IMPLEMENTED
}
auto ret = reader_kernel->get_ret_float(0);
return ret;
}

// for int32 and int64
@@ -178,26 +153,8 @@ int64 SNode::read_int(const std::vector<int> &I) {
get_current_program().synchronize();
(*reader_kernel)();
get_current_program().synchronize();
auto ret = fetch_reader_result();
if (dt == DataType::i32) {
return taichi_union_cast_with_different_sizes<int32>(ret);
} else if (dt == DataType::i64) {
return taichi_union_cast_with_different_sizes<int64>(ret);
} else if (dt == DataType::i8) {
return taichi_union_cast_with_different_sizes<int8>(ret);
} else if (dt == DataType::i16) {
return taichi_union_cast_with_different_sizes<int16>(ret);
} else if (dt == DataType::u8) {
return taichi_union_cast_with_different_sizes<uint8>(ret);
} else if (dt == DataType::u16) {
return taichi_union_cast_with_different_sizes<uint16>(ret);
} else if (dt == DataType::u32) {
return taichi_union_cast_with_different_sizes<uint32>(ret);
} else if (dt == DataType::u64) {
return taichi_union_cast_with_different_sizes<uint64>(ret);
} else {
TI_NOT_IMPLEMENTED
}
auto ret = reader_kernel->get_ret_int(0);
return ret;
}

uint64 SNode::read_uint(const std::vector<int> &I) {
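To close, a hedged sketch of the host-side path the snode.cpp simplification serves: reading a field element now goes through the reader kernel's scalar return (get_ret_float / get_ret_int) rather than the removed fetch_reader_result(). The ti.var field API and the concrete values below are assumptions matching the Taichi release around this PR.

import taichi as ti

ti.init(arch=ti.cpu)  # assumption: any supported backend

x = ti.var(dt=ti.f32, shape=())  # 0-D field

@ti.kernel
def fill():
    x[None] = 42.0

fill()
# The host-side read launches the SNode reader kernel and, after this PR,
# fetches its value via the kernel's scalar return slot.
print(x[None])  # expected to print 42.0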