[PHI] Sum op migration #46239
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open source project!
… sum_operator_migration
@piotrekobi, @Silv3S can you review please?
I left a few minor comments but overall a very good job
LGTM 👍
```cpp
AddOneDNNHandler(dnnl::engine engine,
                 const phi::Place& cpu_place,
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* z)
```
Maybe change it to `out`, because `z` is unintuitive.
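In suggestion form, the rename would look something like this (a sketch only; the actual PR may also rename the uses of `z` inside the constructor body):

```diff
 AddOneDNNHandler(dnnl::engine engine,
                  const phi::Place& cpu_place,
                  const std::vector<const TensorBase*>& x,
-                 DenseTensor* z)
+                 DenseTensor* out)
```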
```cpp
PADDLE_ENFORCE_EQ(
    dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
    true,
    phi::errors::PreconditionNotMet("Operator DNNL Sum must use CPUPlace"));
```
DNNL is deprecated, please replace it with oneDNN. I'd change `Operator DNNL Sum` either to `oneDNN Sum primitive` or `oneDNN AddN kernel`.
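As a suggestion-style sketch, assuming the `oneDNN AddN kernel` wording is chosen (the PR may settle on the other option instead):

```diff
-    phi::errors::PreconditionNotMet("Operator DNNL Sum must use CPUPlace"));
+    phi::errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace"));
```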
```cpp
namespace phi {
namespace funcs {
template <typename T>
class AddOneDNNHandler
```
We usually name handlers after oneDNN primitive names. For example, `scale` or `gelu` kernels use `ActivationKernel`, because oneDNN treats such linear operations as activations.
So what name do you suggest instead?
`SumOneDNNHandler`
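In suggestion form, the proposed rename (a sketch only; the remaining references to `AddOneDNNHandler` in the file would change accordingly):

```diff
 template <typename T>
-class AddOneDNNHandler
+class SumOneDNNHandler
     : public phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::sum> {
```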
```cpp
if (in_place) {
  dst_mem = srcs_mem[0];
} else {
  dst_mem = handler.AcquireDstMemory(out);
}
```
```cpp
auto dst_mem = in_place ? srcs_mem[0] : handler.AcquireDstMemory(out);
```
Please remove the unnecessary namespace names. Great job once again :D!
```diff
 namespace funcs {
 template <typename T>
 class AddOneDNNHandler
-    : public phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::sum> {
+    : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
```
```diff
       const std::vector<const TensorBase*>& x,
       DenseTensor* z)
-      : phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
+      : OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
```
```diff
       : phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
         num_inputs_(0) {
-    auto dst_tz = phi::vectorize<int64_t>(z->dims());
+    auto dst_tz = vectorize<int64_t>(z->dims());
```
```diff
                          to_void_cast<T>(input_data));
   }

-  using phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
+  using OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
```
```diff
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* out) {
   PADDLE_ENFORCE_EQ(
-      dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
+      dev_ctx.GetPlace().GetType() == AllocationType::CPU,
```
```diff
   PADDLE_ENFORCE_EQ(
       dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
       true,
-      phi::errors::PreconditionNotMet("Operator DNNL Sum must use CPUPlace"));
+      errors::PreconditionNotMet("Operator DNNL Sum must use CPUPlace"));
```
```diff
   PADDLE_ENFORCE_NE(x.empty(),
                     true,
-                    phi::errors::InvalidArgument("Input variable is empty."));
+                    errors::InvalidArgument("Input variable is empty."));
```
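Taken together, the suggestions above boil down to relying on the enclosing `phi::funcs` namespace instead of spelling it out. A hedged sketch of the resulting handler skeleton, assembled from the snippets quoted in this review (member types and the elided bodies are illustrative, not the exact PR code):

```cpp
namespace phi {
namespace funcs {

template <typename T>
class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
 public:
  SumOneDNNHandler(dnnl::engine engine,
                   const Place& cpu_place,
                   const std::vector<const TensorBase*>& x,
                   DenseTensor* out)
      : OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
        num_inputs_(0) {
    auto dst_tz = vectorize<int64_t>(out->dims());
    // ... build src memory descriptors and the dnnl::sum primitive descriptor ...
  }

  // Bring the base-class overloads into scope alongside any local ones.
  using OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

 private:
  int num_inputs_;  // illustrative member type; the PR may differ
};

}  // namespace funcs
}  // namespace phi
```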
… sum_operator_migration
LGTM
@jczaja I think it is ready to be merged
LGTM
Your PR has been merged into the Paddle repository. Please keep an eye on the subsequent test results.
* Sum kernel migrated to phi
* Static cast added, file name changed
* OneDNNGetDataType to uppercase
* refactoring
* AddOneDNNHandler changed to SumOneDNNHandler
PR types: Others
PR changes: Others

Describe: Migrate the sum (add) operator to PHI.
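For readers following the PHI migration pattern, a hedged sketch of what the migrated kernel entry point looks like, assembled from the snippets quoted in the review above (the function name, the accessor used to obtain the oneDNN engine, and the elided body are assumptions, not the exact PR code):

```cpp
namespace phi {

template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
                const std::vector<const TensorBase*>& x,
                DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      dev_ctx.GetPlace().GetType() == AllocationType::CPU,
      true,
      errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace"));
  PADDLE_ENFORCE_NE(
      x.empty(), true, errors::InvalidArgument("Input variable is empty."));

  // Build the handler from the oneDNN engine and the inputs/output
  // (GetEngine() is assumed here; the actual accessor may differ).
  funcs::SumOneDNNHandler<T> handler(
      dev_ctx.GetEngine(), dev_ctx.GetPlace(), x, out);

  // ... acquire src/dst memories, create dnnl::sum, and execute it on the
  // context's oneDNN stream ...
}

}  // namespace phi
```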