【PTen】Add variadic args kernel for PTen API to replace KernelContext #37942
paddle/pten/core/kernel_factory.h
Outdated
@@ -244,6 +252,7 @@ class Kernel {

 private:
  KernelFn fn_{nullptr};
  void* variadic_args_kernel_fn_ = nullptr;
Align the naming? variadic_args_kernel_fn_ -> variadic_args_fn_
Done. Renamed to variadic_fn_.
paddle/pten/core/kernel_factory.h
Outdated
@@ -232,6 +232,14 @@ class Kernel {

  void operator()(KernelContext* ctx) const { fn_(ctx); }

  template <typename Fn>
  Fn get_kernel_fn() const {
A getter should be named after the member it returns: get_kernel_fn -> variadic_args_fn
Done. Since this is not purely a getter, it was renamed to GetVariadicKernelFn.
paddle/pten/core/kernel_factory.h
Outdated
    return func;
  }

  void set_variadic_args_kernel_fn(void* fn) { variadic_args_kernel_fn_ = fn; }
Same as above: a setter should be named set_ plus the member name.
The member variable variadic_fn_ is now assigned in the constructor, and the set method has been removed.
@@ -224,12 +238,12 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   template <typename T>
   struct KernelCallHelper<TypeTag<T>> {
     template <int dev_ctx_idx, int in_idx, int attr_idx, int out_idx>
-    static void Compute(KernelContext* ctx, Args&... args) {
+    static void Compute(KernelContext* ctx, DevCtx dev_ctx, Args&... args) {
What is the reason for this change?
This pulls the DeviceContext parameter out of the argument pack so that it can be converted for the variadic-args kernel.
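A minimal sketch of what separating the device context from the rest of the argument pack makes possible, assuming the goal is an entry point that accepts a base-class context and converts it to the concrete type the kernel expects. The names DeviceContext, CPUContext, ScaleKernel, RunScale and VariadicCompute are hypothetical stand-ins chosen for illustration, not code taken from this PR.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-ins for the framework's context types.
struct DeviceContext {
  virtual ~DeviceContext() = default;
};
struct CPUContext : DeviceContext {};

template <typename Fn, Fn fn>
struct KernelImpl;

// DevCtx is named separately from Args... so it can be handled on its own.
template <typename Return, typename DevCtx, typename... Args,
          Return (*kernel_fn)(DevCtx, Args...)>
struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
  // Assumed entry point: takes the base context, converts it to the
  // kernel's concrete context type, and forwards the remaining arguments.
  static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) {
    kernel_fn(static_cast<DevCtx>(dev_ctx), std::forward<Args>(args)...);
  }
};

// Illustrative kernel written as an ordinary function.
void ScaleKernel(const CPUContext& dev_ctx, float x, float scale, float* out) {
  (void)dev_ctx;  // unused in this toy kernel
  assert(out != nullptr);
  *out = x * scale;
}

// Usage: the caller only holds a CPUContext and plain arguments.
float RunScale(float x, float scale) {
  CPUContext cpu;
  float out = 0.0f;
  KernelImpl<decltype(&ScaleKernel), &ScaleKernel>::VariadicCompute(
      cpu, x, scale, &out);
  return out;
}
```

If DevCtx stayed inside Args..., the wrapper could not single out the first parameter for the static_cast without extra metaprogramming.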
paddle/pten/core/kernel_registry.h
Outdated
@@ -137,10 +141,12 @@ struct KernelRegistrar {
       DataType dtype,
       KernelArgsParseFn args_parse_fn,
       KernelArgsDefFn args_def_fn,
-      KernelFn kernel_fn) {
+      KernelFn kernel_fn,
+      void* variadic_args_kernel_fn) {
     KernelName kernel_name(kernel_name_cstr);
     KernelKey kernel_key(backend, layout, dtype);
     Kernel kernel(kernel_fn);
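Should we enforce here that a Kernel can only be constructed when both functions are provided? Wouldn't changing the constructor be better than calling set afterwards?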
Done. The set method has been removed, and the constructor now takes two arguments.
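The outcome of this review thread can be sketched roughly as follows: a Kernel that receives both function pointers at construction, stores the variadic one type-erased as void*, and hands it back through a templated accessor. This is an illustration under assumptions, not the merged code: the KernelContext layout, ScaleBoxed, ScaleVariadic and MakeScaleKernel are hypothetical, and only the names variadic_fn_ and GetVariadicKernelFn come from the discussion above.

```cpp
#include <cassert>

// Hypothetical, flattened stand-in for the real KernelContext.
struct KernelContext {
  float x;
  float scale;
  float out;
};

using KernelFn = void (*)(KernelContext*);

class Kernel {
 public:
  // Both functions must be supplied up front; there is no setter.
  Kernel(KernelFn fn, void* variadic_fn) : fn_(fn), variadic_fn_(variadic_fn) {}

  void operator()(KernelContext* ctx) const { fn_(ctx); }

  // Not a plain getter: it converts the stored pointer back to the
  // function pointer type requested by the caller.
  template <typename Fn>
  Fn GetVariadicKernelFn() const {
    assert(variadic_fn_ != nullptr);
    // Converting between void* and a function pointer is conditionally
    // supported in C++; mainstream compilers on common platforms accept it.
    return reinterpret_cast<Fn>(variadic_fn_);
  }

 private:
  KernelFn fn_{nullptr};
  void* variadic_fn_ = nullptr;
};

// Illustrative kernel in both forms.
void ScaleVariadic(float x, float scale, float* out) { *out = x * scale; }
void ScaleBoxed(KernelContext* ctx) {
  ScaleVariadic(ctx->x, ctx->scale, &ctx->out);
}

Kernel MakeScaleKernel() {
  return Kernel(&ScaleBoxed, reinterpret_cast<void*>(&ScaleVariadic));
}
```

The caller has to request exactly the signature that was registered, because the void* carries no type information; a mismatch is undefined behavior rather than a compile error.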
PR types
Performance optimization
PR changes
Others
Describe
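PTen now registers kernels as variadic template functions, replacing the previous kernel registration functions that take a KernelContext, in order to eliminate the performance overhead of constructing the KernelContext data structure. As shown in the figure above, in the test_scale_benchmark performance comparison, calling the new-style kernel that uses a variadic template function (variadic_args_kernel_fn) is about 25% faster than calling the original kernel that uses KernelContext. In addition, the variadic template function performs about the same as calling the kernel directly through switch-case (already close to the theoretical optimum), while the code is more concise.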
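To make the two call paths in the description concrete, here is a toy sketch, assuming a context that packs arguments into containers. The KernelContext layout and the functions Scale, ScaleWithContext, CallViaContext and CallViaVariadic are hypothetical and far simpler than the real PTen types. The sketch illustrates where the extra work comes from and does not reproduce the benchmark.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, heavily simplified context: arguments are boxed into vectors.
struct KernelContext {
  std::vector<const float*> inputs;
  std::vector<float> attrs;
  std::vector<float*> outputs;
};

// The computation itself, written as an ordinary function.
void Scale(const float* x, float scale, float* out) { *out = *x * scale; }

// Old style: the kernel unpacks its arguments from the context.
void ScaleWithContext(KernelContext* ctx) {
  assert(!ctx->inputs.empty() && !ctx->attrs.empty() && !ctx->outputs.empty());
  Scale(ctx->inputs[0], ctx->attrs[0], ctx->outputs[0]);
}

// Context path: a context is built (with heap-allocating containers)
// on every call before the kernel runs.
float CallViaContext(float x, float scale) {
  float out = 0.0f;
  KernelContext ctx;
  ctx.inputs.push_back(&x);
  ctx.attrs.push_back(scale);
  ctx.outputs.push_back(&out);
  ScaleWithContext(&ctx);
  return out;
}

// Variadic path: the type-erased pointer is cast back to its signature
// and called directly with the arguments.
float CallViaVariadic(void* erased_fn, float x, float scale) {
  using ScaleFn = void (*)(const float*, float, float*);
  float out = 0.0f;
  reinterpret_cast<ScaleFn>(erased_fn)(&x, scale, &out);
  return out;
}
```

Both paths compute the same result; the variadic path simply skips building and reading the intermediate container.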