[Lang] [IR] Kernel scalar return support (ArgStoreStmt -> KernelReturnStmt) #917
```diff
@@ -498,6 +498,14 @@ class KernelGen : public IRVisitor {
          const_stmt->short_name(), const_stmt->val[0].stringify());
   }
 
+  void visit(KernelReturnStmt *stmt) override {
+    used.argument = true;
+    used.int64 = true;
+    emit("_args_{}_[0] = {};",  // TODO: correct idx, another buf
+         "i64",  // data_type_short_name(stmt->element_type()),
+         stmt->value->short_name());
+  }
+
   void visit(ArgLoadStmt *stmt) override {
     const auto dt = opengl_data_type_name(stmt->element_type());
     used.argument = true;
```

> Review comment on the `"i64"` line: Too bad, still using […]
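For orientation (not part of the diff): `emit()` here is fmt-style string formatting, so the visitor writes the scalar return value into slot 0 of the i64-typed argument buffer, which is why `used.argument` and `used.int64` are flagged. A minimal standalone reproduction of the string expansion, where the temporary name `a4` is made up in place of `stmt->value->short_name()`:

```cpp
#include <fmt/format.h>
#include <iostream>

int main() {
  // Same format string as the visitor's emit() call above.
  std::cout << fmt::format("_args_{}_[0] = {};", "i64", "a4") << "\n";
  // prints the generated GLSL statement:  _args_i64_[0] = a4;
}
```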
```diff
@@ -1453,6 +1453,22 @@ class FuncCallStmt : public Stmt {
   DEFINE_ACCEPT
 };
 
+class KernelReturnStmt : public Stmt {
+ public:
+  Stmt *value;
+
+  KernelReturnStmt(Stmt *value) : value(value) {
+    TI_STMT_REG_FIELDS;
+  }
+
+  bool is_container_statement() const override {
+    return false;
+  }
+
+  TI_STMT_DEF_FIELDS(value);
+  DEFINE_ACCEPT
+};
+
 class WhileStmt : public Stmt {
  public:
   Stmt *mask;
```

> Review comment on `Stmt *value;`: Let's put multi-return in another PR.
```diff
@@ -122,11 +122,13 @@ void SNode::write_float(const std::vector<int> &I, float64 val) {
   (*writer_kernel)();
 }
 
+// TODO: use kernel.get_ret_float instead
 uint64 SNode::fetch_reader_result() {
   uint64 ret;
   auto arch = get_current_program().config.arch;
   if (arch == Arch::cuda) {
     // TODO: refactor
+    // XXX: what about unified memory?
 #if defined(TI_WITH_CUDA)
     CUDADriver::get_instance().memcpy_device_to_host(
         &ret, get_current_program().result_buffer, sizeof(uint64));
@@ -141,6 +143,7 @@ uint64 SNode::fetch_reader_result() {
   return ret;
 }
 
+// TODO
 float64 SNode::read_float(const std::vector<int> &I) {
   if (reader_kernel == nullptr) {
     reader_kernel = &get_current_program().get_snode_reader(this);
@@ -159,6 +162,7 @@ float64 SNode::read_float(const std::vector<int> &I) {
   }
 }
 
+// TODO
 // for int32 and int64
 void SNode::write_int(const std::vector<int> &I, int64 val) {
   if (writer_kernel == nullptr) {
```
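A hedged sketch of where the `// TODO: use kernel.get_ret_float instead` comment appears to be heading (an assumption, not part of this diff): once reader kernels return through the new mechanism, the per-arch branching in `fetch_reader_result()` collapses into a call to the `Kernel` API added later in this PR:

```cpp
float64 SNode::read_float(const std::vector<int> &I) {
  if (reader_kernel == nullptr) {
    reader_kernel = &get_current_program().get_snode_reader(this);
  }
  // ... set the index arguments exactly as before ...
  (*reader_kernel)();
  // Replaces fetch_reader_result(): the reader kernel's scalar return value
  // is read back through the shared result buffer.
  return reader_kernel->get_ret_float(0);
}
```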
```diff
@@ -4,6 +4,7 @@
 #include "taichi/program/program.h"
 #include "taichi/program/async_engine.h"
 #include "taichi/codegen/codegen.h"
+#include "taichi/backends/cuda/cuda_driver.h"
 
 TLANG_NAMESPACE_BEGIN
```
```diff
@@ -111,13 +112,9 @@ void Kernel::set_arg_float(int i, float64 d) {
   }
 }
 
-void Kernel::set_extra_arg_int(int i, int j, int32 d) {
-  program.context.extra_args[i][j] = d;
-}
-
 void Kernel::set_arg_int(int i, int64 d) {
   TI_ASSERT_INFO(
-      args[i].is_nparray == false,
+      !args[i].is_nparray,
       "Assigning scalar value to numpy array argument is not allowed");
   auto dt = args[i].dt;
   if (dt == DataType::i32) {
```
```diff
@@ -145,10 +142,97 @@ void Kernel::set_arg_int(int i, int64 d) {
   }
 }
 
+// XXX: sync with snode.cpp: fetch_reader_result
+static uint64 fetch_result_uint64(int i) {
+  uint64 ret;
+  auto arch = get_current_program().config.arch;
+  if (arch == Arch::cuda) {
+    // TODO: refactor
+    // XXX: what about unified memory?
+#if defined(TI_WITH_CUDA)
+    CUDADriver::get_instance().memcpy_device_to_host(
+        &ret, (uint64 *)get_current_program().result_buffer + i,
+        sizeof(uint64));
+#else
+    TI_NOT_IMPLEMENTED;
+#endif
+  } else if (arch_is_cpu(arch)) {
+    ret = ((uint64 *)get_current_program().result_buffer)[i];
+  } else {
+    ret = get_current_program().context.get_arg_as_uint64(i);
+  }
+  return ret;
+}
+
+template <typename T>
+static T fetch_result(int i) {  // TODO: move to Program::fetch_result
+  return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
+}
```

> Review comment on `fetch_result`: It would be great to move […]
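Both the `// TODO: move to Program::fetch_result` comment and the review note point at the same refactor; a minimal sketch of it, where the exact method placement is an assumption:

```cpp
// Hoisting the helpers into Program would remove the duplicated copy in
// snode.cpp (fetch_reader_result) and the repeated get_current_program()
// round-trips.
class Program {
 public:
  uint64 fetch_result_uint64(int i);  // same per-arch body as the static helper

  template <typename T>
  T fetch_result(int i) {
    return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
  }
};
```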
```diff
+float64 Kernel::get_ret_float(int i) {
+  auto dt = rets[i].dt;
+  if (dt == DataType::f32) {
+    return (float64)fetch_result<float32>(i);
+  } else if (dt == DataType::f64) {
+    return (float64)fetch_result<float64>(i);
+  } else if (dt == DataType::i32) {
+    return (float64)fetch_result<int32>(i);
+  } else if (dt == DataType::i64) {
+    return (float64)fetch_result<int64>(i);
+  } else if (dt == DataType::i8) {
+    return (float64)fetch_result<int8>(i);
+  } else if (dt == DataType::i16) {
+    return (float64)fetch_result<int16>(i);
+  } else if (dt == DataType::u8) {
+    return (float64)fetch_result<uint8>(i);
+  } else if (dt == DataType::u16) {
+    return (float64)fetch_result<uint16>(i);
+  } else if (dt == DataType::u32) {
+    return (float64)fetch_result<uint32>(i);
+  } else if (dt == DataType::u64) {
+    return (float64)fetch_result<uint64>(i);
+  } else {
+    TI_NOT_IMPLEMENTED
+  }
+}
+
+int64 Kernel::get_ret_int(int i) {
+  auto dt = rets[i].dt;
+  if (dt == DataType::i32) {
+    return (int64)fetch_result<int32>(i);
+  } else if (dt == DataType::i64) {
+    return (int64)fetch_result<int64>(i);
+  } else if (dt == DataType::i8) {
+    return (int64)fetch_result<int8>(i);
+  } else if (dt == DataType::i16) {
+    return (int64)fetch_result<int16>(i);
+  } else if (dt == DataType::u8) {
+    return (int64)fetch_result<uint8>(i);
+  } else if (dt == DataType::u16) {
+    return (int64)fetch_result<uint16>(i);
+  } else if (dt == DataType::u32) {
+    return (int64)fetch_result<uint32>(i);
+  } else if (dt == DataType::u64) {
+    return (int64)fetch_result<uint64>(i);
+  } else if (dt == DataType::f32) {
+    return (int64)fetch_result<float32>(i);
+  } else if (dt == DataType::f64) {
+    return (int64)fetch_result<float64>(i);
+  } else {
+    TI_NOT_IMPLEMENTED
+  }
+}
```

> Review thread attached to the `i8` branch of `get_ret_float`:
>
> - Would it be better to […]
> - It's okay since these are ABIs and are only used by our Python code; an end user never calls this.
> - Sounds good. I trust your decision.
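Not something this PR does: the two dispatch chains above could also be generated from a single type list with an X-macro. A sketch, with macro names invented:

```cpp
#define PER_SCALAR_TYPE(F)                                                 \
  F(f32, float32) F(f64, float64) F(i8, int8) F(i16, int16) F(i32, int32) \
  F(i64, int64) F(u8, uint8) F(u16, uint16) F(u32, uint32) F(u64, uint64)

float64 Kernel::get_ret_float(int i) {
  auto dt = rets[i].dt;
  // Each F(tag, type) expands to one branch of the original if-else chain.
#define CASE(tag, cpp_type) \
  if (dt == DataType::tag)  \
    return (float64)fetch_result<cpp_type>(i);
  PER_SCALAR_TYPE(CASE)
#undef CASE
  TI_NOT_IMPLEMENTED
}
```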
```diff
+void Kernel::mark_arg_return_value(int i, bool is_return) {
+  args[i].is_return_value = is_return;
+}
+
+void Kernel::set_extra_arg_int(int i, int j, int32 d) {
+  program.context.extra_args[i][j] = d;
+}
+
 void Kernel::set_arg_nparray(int i, uint64 ptr, uint64 size) {
   TI_ASSERT_INFO(args[i].is_nparray,
                  "Assigning numpy array to scalar argument is not allowed");
@@ -166,4 +250,9 @@ int Kernel::insert_arg(DataType dt, bool is_nparray) {
   return args.size() - 1;
 }
 
+int Kernel::insert_ret(DataType dt) {
+  rets.push_back(Ret{dt});
+  return rets.size() - 1;
+}
+
 TLANG_NAMESPACE_END
```
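Putting the new `Kernel` API together, host-side usage would look roughly like this; the call sites are invented for illustration, but every method shown is added by this diff:

```cpp
int slot = kernel->insert_ret(DataType::i32);  // declare a scalar return slot
(*kernel)();                                   // launch; the backend writes the
                                               // value into the result buffer
int64 result = kernel->get_ret_int(slot);      // read back, converted per rets[slot].dt
```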
> Review thread:
>
> - Does the following code check if the return statement is at the end of a kernel?
> - It doesn't for now; it's a TODO item.
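For reference, a hedged sketch of what that TODO check could look like as a small verification pass (the pass and its name are invented, not part of this PR):

```cpp
// Rejects kernels where a KernelReturnStmt is followed by other statements.
void verify_return_is_last(Block *body) {
  const int n = (int)body->statements.size();
  for (int i = 0; i < n; i++) {
    if (body->statements[i]->is<KernelReturnStmt>() && i != n - 1) {
      TI_ERROR("KernelReturnStmt must be the last statement of a kernel");
    }
  }
}
```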