[Lang] [IR] Kernel scalar return support (ArgStoreStmt -> KernelReturnStmt) #917

Merged 25 commits on May 8, 2020.
Changes from 13 commits.
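For context, a minimal usage sketch of what this PR enables: a kernel annotated with a scalar return type whose value is read back on the host. This example is illustrative rather than taken from the PR's tests, and it assumes the Taichi Python API of this period (ti.init, scalar type annotations).

```python
import taichi as ti

ti.init(arch=ti.cpu)  # assumes a backend with KernelReturnStmt support (LLVM CPU/CUDA in this PR)

@ti.kernel
def add_one(v: ti.f32) -> ti.f32:
    # The "-> ti.f32" annotation is required: returning a value without it
    # trips the assertion added in transformer.py below.
    return v + 1.0

print(add_one(2.0))  # expected: 3.0
```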
15 changes: 14 additions & 1 deletion python/taichi/lang/kernel.py
@@ -158,6 +158,7 @@ def __init__(self, func, is_grad, classkernel=False):
self.is_grad = is_grad
self.arguments = []
self.argument_names = []
self.return_type = None
self.classkernel = classkernel
self.extract_arguments()
self.template_slot_locations = []
@@ -180,6 +181,8 @@ def reset(self):

def extract_arguments(self):
sig = inspect.signature(self.func)
if sig.return_annotation not in (inspect._empty, None):
self.return_type = sig.return_annotation
params = sig.parameters
arg_names = params.keys()
for i, arg_name in enumerate(arg_names):
@@ -388,12 +391,22 @@ def call_back():

t_kernel()

ret = None
ret_dt = self.return_type
if ret_dt is not None:
if taichi_lang_core.is_integral(ret_dt):
ret = t_kernel.get_ret_int(0)
else:
ret = t_kernel.get_ret_float(0)

if callbacks:
import taichi as ti
ti.sync()
for c in callbacks:
c()

return ret

return func__

def match_ext_arr(self, v, needed):
@@ -481,7 +494,7 @@ def wrapped(*args, **kwargs):

@wraps(func)
def wrapped(*args, **kwargs):
primal(*args, **kwargs)
return primal(*args, **kwargs)

wrapped.grad = adjoint

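A side note on the return-annotation extraction in extract_arguments above: inspect.signature reports a missing return annotation as a sentinel object, which is why the diff compares against inspect._empty. A standalone sketch in plain Python (no Taichi involved):

```python
import inspect

def annotated(x: int) -> float:
    return x * 0.5

def unannotated(x):
    return x

# An explicit annotation is returned as-is; a missing one is the sentinel
# inspect.Signature.empty (inspect._empty refers to the same object).
assert inspect.signature(annotated).return_annotation is float
assert inspect.signature(unannotated).return_annotation is inspect.Signature.empty
assert inspect.Signature.empty is inspect._empty
```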
5 changes: 5 additions & 0 deletions python/taichi/lang/kernel_arguments.py
@@ -40,3 +40,8 @@ def decl_scalar_arg(dt):
def decl_ext_arr_arg(dt, dim):
id = taichi_lang_core.decl_arg(dt, True)
return Expr(taichi_lang_core.make_external_tensor_expr(dt, dim, id))


def decl_scalar_ret(dt):
id = taichi_lang_core.decl_ret(dt)
return id
26 changes: 23 additions & 3 deletions python/taichi/lang/transformer.py
@@ -37,6 +37,7 @@ def __init__(self,
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features
self.returns = None

def variable_scope(self, *args):
return ScopeGuard(self, *args)
@@ -594,6 +595,14 @@ def visit_FunctionDef(self, node):
"Function definition not allowed in 'ti.kernel'.")
# Transform as kernel
arg_decls = []

# Treat return type
if node.returns is not None:
ret_init = self.parse_stmt('ti.decl_scalar_ret(0)')
ret_init.value.args[0] = node.returns
self.returns = node.returns
arg_decls.append(ret_init)

for i, arg in enumerate(args.args):
if isinstance(self.func.arguments[i], ti.template):
continue
@@ -620,6 +629,7 @@ def visit_FunctionDef(self, node):
arg_decls.append(arg_init)
# remove original args
node.args.args = []

else: # ti.func
for decorator in node.decorator_list:
if (isinstance(decorator, ast.Attribute)
@@ -640,8 +650,10 @@ def visit_FunctionDef(self, node):
'_by_value__')
args.args[i].arg += '_by_value__'
arg_decls.append(arg_init)

with self.variable_scope():
self.generic_visit(node)

node.body = arg_decls + node.body
return node

@@ -736,7 +748,15 @@ def visit_Assert(self, node):
def visit_Return(self, node):
self.generic_visit(node)
if self.is_kernel:
raise TaichiSyntaxError(
'"return" not allowed in \'ti.kernel\'. Please walk around by storing the return result to a global variable.'
)
# todo: we only support return at the end of kernel, check this
Review suggestion (Member):
- # todo: we only support return at the end of kernel, check this
+ # TODO: we only support return at the end of a kernel. The following code does a check for this.

Review comment (Contributor): Does the following code check if the return statement is at the end of a kernel?

Reply (Collaborator, author): It doesn't for now, it's a TODO item.
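For reference, one possible shape of that future check, sketched as a standalone pass over the Python AST; the helper name and its eventual call site are assumptions, not code in this PR:

```python
import ast

def check_return_is_last(fn: ast.FunctionDef):
    """Allow `return` only as the final top-level statement of a kernel body (sketch)."""
    body = fn.body
    for stmt in body[:-1]:
        # Any return before the last top-level statement, nested or not, is rejected.
        if any(isinstance(n, ast.Return) for n in ast.walk(stmt)):
            raise SyntaxError('return is only supported at the end of a kernel')
    if body and not isinstance(body[-1], ast.Return):
        # A return nested inside a trailing if/for is not "at the end" either.
        if any(isinstance(n, ast.Return) for n in ast.walk(body[-1])):
            raise SyntaxError('return is only supported at the end of a kernel')
```

Such a check could be called from visit_FunctionDef before the body is visited (a hypothetical call site).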

if node.value is not None:
assert self.returns is not None, 'kernel with return value must be ' \
'annotated with return type, e.g. def func() -> ti.f32'
ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
ret_stmt = self.parse_stmt(
'ti.core.create_kernel_return(ret.ptr)')
ret_stmt.value.args[0].value = ret_expr
return ret_stmt
return node
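The transformer builds these injected statements by parsing template source and then splicing AST nodes into placeholder slots (for example, ret_init.value.args[0] = node.returns above). A standalone demonstration of that splicing trick in plain Python; the kernel source here is made up for illustration:

```python
import ast

kernel_src = (
    "def f(x) -> ti.f32:\n"
    "    return x * 2\n"
)
fn = ast.parse(kernel_src).body[0]

# Parse a template statement with a dummy argument, then replace the dummy
# with the kernel's return-annotation node, mirroring what visit_FunctionDef does.
ret_init = ast.parse("ti.decl_scalar_ret(0)").body[0]
ret_init.value.args[0] = fn.returns

print(ast.dump(ret_init.value))
# Roughly: Call(func=Attribute(value=Name(id='ti'), attr='decl_scalar_ret'),
#               args=[Attribute(value=Name(id='ti'), attr='f32')], keywords=[])
```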
8 changes: 8 additions & 0 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -498,6 +498,14 @@ class KernelGen : public IRVisitor {
const_stmt->short_name(), const_stmt->val[0].stringify());
}

void visit(KernelReturnStmt *stmt) override {
used.argument = true;
used.int64 = true;
emit("_args_{}_[0] = {};", // TD: correct idx, another buf
"i64",//data_type_short_name(stmt->element_type()),
stmt->value->short_name());
Review comment (Collaborator, author): Too bad, still using arg[0] for return. We want ret[0] instead. Also, please update context.get_ret_int.

}

void visit(ArgLoadStmt *stmt) override {
const auto dt = opengl_data_type_name(stmt->element_type());
used.argument = true;
1 change: 1 addition & 0 deletions taichi/backends/opengl/opengl_api.cpp
@@ -508,6 +508,7 @@ GLSLLaunchGuard::~GLSLLaunchGuard() {
if (!iov[i].size)
continue;
void *p = impl->ssbo[i].map(); // 0, iov[i].size); // output
if (i == 0) TI_INFO("{} or {}", ((int *)p)[0], ((float *)p)[0]);
std::memcpy(iov[i].base, p, iov[i].size);
}
impl->ssbo.clear();
18 changes: 18 additions & 0 deletions taichi/codegen/codegen_llvm.cpp
@@ -881,6 +881,24 @@ void CodeGenLLVM::visit(ArgStoreStmt *stmt) {
}
}

void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->is_ptr) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits = tlctx->get_data_type(stmt->value->ret_type.data_type)
->getPrimitiveSizeInBits();
llvm::Type *intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
auto extended = builder->CreateZExt(
builder->CreateBitCast(llvm_val[stmt->value], intermediate_type),
dest_ty);
//extended = llvm::ConstantFP::get(*llvm_context, llvm::APFloat(666.6));//
builder->CreateCall(get_runtime_function("LLVMRuntime_store_result"),
{get_runtime(), extended});
}
}

void CodeGenLLVM::visit(LocalLoadStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->ptr[0].var]);
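CodeGenLLVM::visit(KernelReturnStmt *) above packs any scalar into a 64-bit result slot by bitcasting it to an integer of the same width and zero-extending it. A standalone Python sketch of the equivalent bit-level round trip for an f32 return (just the bit manipulation, not Taichi code):

```python
import struct

def store_result_f32(x: float) -> int:
    """Device side, conceptually: bitcast f32 -> u32, zero-extend into a u64 slot."""
    return struct.unpack('<I', struct.pack('<f', x))[0]

def get_ret_float_f32(slot: int) -> float:
    """Host side: reinterpret the low 32 bits of the slot as an f32."""
    return struct.unpack('<f', struct.pack('<I', slot & 0xFFFFFFFF))[0]

slot = store_result_f32(3.25)
assert slot < 2**64                      # fits the uint64 result slot
assert get_ret_float_f32(slot) == 3.25
```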
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
@@ -171,6 +171,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ArgStoreStmt *stmt) override;

void visit(KernelReturnStmt *stmt) override;

void visit(LocalLoadStmt *stmt) override;

void visit(LocalStoreStmt *stmt) override;
2 changes: 2 additions & 0 deletions taichi/inc/statements.inc.h
@@ -12,6 +12,7 @@ PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendArgStoreStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FrontendKernelReturnStmt)

// Middle-end statement

@@ -24,6 +25,7 @@ PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncBodyStmt)
PER_STATEMENT(FuncCallStmt)
PER_STATEMENT(KernelReturnStmt)

PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ExternalPtrStmt)
14 changes: 14 additions & 0 deletions taichi/ir/frontend_ir.h
@@ -198,6 +198,20 @@ class FrontendWhileStmt : public Stmt {
DEFINE_ACCEPT
};

class FrontendKernelReturnStmt : public Stmt {
public:
Expr value;

FrontendKernelReturnStmt(const Expr &value) : value(value) {
}

bool is_container_statement() const override {
return false;
}

DEFINE_ACCEPT
};

// Expressions

class ArgLoadExpression : public Expression {
16 changes: 16 additions & 0 deletions taichi/ir/ir.h
@@ -1453,6 +1453,22 @@ class FuncCallStmt : public Stmt {
DEFINE_ACCEPT
};

class KernelReturnStmt : public Stmt {
public:
Stmt *value;
Review comment (Collaborator, author): Let's put multi-return in another PR.

KernelReturnStmt(Stmt *value) : value(value) {
TI_STMT_REG_FIELDS;
}

bool is_container_statement() const override {
return false;
}

TI_STMT_DEF_FIELDS(value);
DEFINE_ACCEPT
};

class WhileStmt : public Stmt {
public:
Stmt *mask;
4 changes: 4 additions & 0 deletions taichi/ir/snode.cpp
@@ -122,11 +122,13 @@ void SNode::write_float(const std::vector<int> &I, float64 val) {
(*writer_kernel)();
}

// TODO: use kernel.get_ret_float instead
uint64 SNode::fetch_reader_result() {
uint64 ret;
auto arch = get_current_program().config.arch;
if (arch == Arch::cuda) {
// TODO: refactor
// XXX: what about unified memory?
Review suggestion (Member):
- // XXX: what about unified memory?
+ // We use a `memcpy_device_to_host` call here even if we have unified memory. This simplifies code. Also note that a unified memory (4KB) page fault is rather expensive for reading 4-8 bytes.

#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().memcpy_device_to_host(
&ret, get_current_program().result_buffer, sizeof(uint64));
@@ -141,6 +143,7 @@ uint64 SNode::fetch_reader_result() {
return ret;
}

// TODO
float64 SNode::read_float(const std::vector<int> &I) {
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
@@ -159,6 +162,7 @@ float64 SNode::read_float(const std::vector<int> &I) {
}
}

// TODO
// for int32 and int64
void SNode::write_int(const std::vector<int> &I, int64 val) {
if (writer_kernel == nullptr) {
99 changes: 94 additions & 5 deletions taichi/program/kernel.cpp
@@ -4,6 +4,7 @@
#include "taichi/program/program.h"
#include "taichi/program/async_engine.h"
#include "taichi/codegen/codegen.h"
#include "taichi/backends/cuda/cuda_driver.h"

TLANG_NAMESPACE_BEGIN

@@ -111,13 +112,9 @@ void Kernel::set_arg_float(int i, float64 d) {
}
}

void Kernel::set_extra_arg_int(int i, int j, int32 d) {
program.context.extra_args[i][j] = d;
}

void Kernel::set_arg_int(int i, int64 d) {
TI_ASSERT_INFO(
args[i].is_nparray == false,
!args[i].is_nparray,
"Assigning scalar value to numpy array argument is not allowed");
auto dt = args[i].dt;
if (dt == DataType::i32) {
@@ -145,10 +142,97 @@ void Kernel::set_arg_int(int i, int64 d) {
}
}

// XXX: sync with snode.cpp: fetch_reader_result
static uint64 fetch_result_uint64(int i)
{
uint64 ret;
auto arch = get_current_program().config.arch;
if (arch == Arch::cuda) {
// TODO: refactor
// XXX: what about unified memory?
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().memcpy_device_to_host(&ret,
(uint64 *)get_current_program().result_buffer + i,
sizeof(uint64));
#else
TI_NOT_IMPLEMENTED;
#endif
} else if (arch_is_cpu(arch)) {
ret = ((uint64 *)get_current_program().result_buffer)[i];
} else {
ret = get_current_program().context.get_arg_as_uint64(i);
}
return ret;
}

template <typename T>
static T fetch_result(int i) // TODO: move to Program::fetch_result
Review comment (Member): It would be great to move fetch_result_uint64 and this to Program in this PR. The modification would not be too big I guess, since you already need a few get_current_program calls in fetch_result_uint64.

{
return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
}

float64 Kernel::get_ret_float(int i) {
auto dt = rets[i].dt;
if (dt == DataType::f32) {
return (float64)fetch_result<float32>(i);
} else if (dt == DataType::f64) {
return (float64)fetch_result<float64>(i);
} else if (dt == DataType::i32) {
return (float64)fetch_result<int32>(i);
} else if (dt == DataType::i64) {
return (float64)fetch_result<int64>(i);
} else if (dt == DataType::i8) {
return (float64)fetch_result<int8>(i);
Review comment (Member): Would it be better to TI_ERROR when people are trying to read an integer return value as float? They can use get_ret_int instead.
Reply (Collaborator, author): It's okay since these are ABIs and only used by our Python code; an end-user never calls this.
Reply (Member): Sounds good. I trust your decision.

} else if (dt == DataType::i16) {
return (float64)fetch_result<int16>(i);
} else if (dt == DataType::u8) {
return (float64)fetch_result<uint8>(i);
} else if (dt == DataType::u16) {
return (float64)fetch_result<uint16>(i);
} else if (dt == DataType::u32) {
return (float64)fetch_result<uint32>(i);
} else if (dt == DataType::u64) {
return (float64)fetch_result<uint64>(i);
} else {
TI_NOT_IMPLEMENTED
}
}

int64 Kernel::get_ret_int(int i) {
auto dt = rets[i].dt;
if (dt == DataType::i32) {
return (int64)fetch_result<int32>(i);
} else if (dt == DataType::i64) {
return (int64)fetch_result<int64>(i);
} else if (dt == DataType::i8) {
return (int64)fetch_result<int8>(i);
} else if (dt == DataType::i16) {
return (int64)fetch_result<int16>(i);
} else if (dt == DataType::u8) {
return (int64)fetch_result<uint8>(i);
} else if (dt == DataType::u16) {
return (int64)fetch_result<uint16>(i);
} else if (dt == DataType::u32) {
return (int64)fetch_result<uint32>(i);
} else if (dt == DataType::u64) {
return (int64)fetch_result<uint64>(i);
} else if (dt == DataType::f32) {
return (int64)fetch_result<float32>(i);
} else if (dt == DataType::f64) {
return (int64)fetch_result<float64>(i);
} else {
TI_NOT_IMPLEMENTED
}
}

void Kernel::mark_arg_return_value(int i, bool is_return) {
args[i].is_return_value = is_return;
}

void Kernel::set_extra_arg_int(int i, int j, int32 d) {
program.context.extra_args[i][j] = d;
}

void Kernel::set_arg_nparray(int i, uint64 ptr, uint64 size) {
TI_ASSERT_INFO(args[i].is_nparray,
"Assigning numpy array to scalar argument is not allowed");
@@ -166,4 +250,9 @@ int Kernel::insert_arg(DataType dt, bool is_nparray) {
return args.size() - 1;
}

int Kernel::insert_ret(DataType dt) {
rets.push_back(Ret{dt});
return rets.size() - 1;
}

TLANG_NAMESPACE_END
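On the host side, get_ret_int and get_ret_float reinterpret the 64-bit result slot according to rets[i].dt (through taichi_union_cast_with_different_sizes). A small Python sketch of why zero-extension on the device plus low-bit signed reinterpretation on the host round-trips negative i32 values correctly:

```python
import struct

def store_result_i32(x: int) -> int:
    """Device side: bitcast i32 -> u32, zero-extend into the u64 result slot."""
    return struct.unpack('<I', struct.pack('<i', x))[0]

def get_ret_int_i32(slot: int) -> int:
    """Host side: take the low 32 bits and reinterpret them as a signed i32."""
    return struct.unpack('<i', struct.pack('<I', slot & 0xFFFFFFFF))[0]

for v in (0, 1, -1, -123456789, 2**31 - 1, -(2**31)):
    assert get_ret_int_i32(store_result_i32(v)) == v
```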