[oneDNN] Concat refactoring and disabling caching #35002
Conversation
Thanks for your contribution!
I really like the initiative to change from the factory to the Acquire API :)
class ConcatMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
 public:
  ConcatMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
Suggested change:
-  ConcatMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
+  ConcatMKLDNNHandler(const framework::ExecutionContext& ctx,
ok
  memory::data_type dt =
      paddle::framework::ToMKLDNNDataType(inputs[0]->type());
Suggested change:
-      paddle::framework::ToMKLDNNDataType(inputs[0]->type());
+      framework::ToMKLDNNDataType(inputs[0]->type());
ok
  srcs_md.reserve(inputs.size());

  // Create memory descriptors for each of inputs
  for (size_t i = 0; i < inputs.size(); i++) {
Suggested change:
-  for (size_t i = 0; i < inputs.size(); i++) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
ok
  // Create memory descriptors for each of inputs
  for (size_t i = 0; i < inputs.size(); i++) {
    const auto dims =
        paddle::framework::vectorize<int64_t>(inputs[i]->dims());
Suggested change:
-        paddle::framework::vectorize<int64_t>(inputs[i]->dims());
+        framework::vectorize<int64_t>(inputs[i]->dims());
    srcs_md.emplace_back(memory::desc(dims, dt, inputs[i]->format()));
  }

  auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
Suggested change:
-  auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
+  auto dst_dims = framework::vectorize<int64_t>(output->dims());
  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
      const mkldnn::memory::desc& dst_md, const int concat_axis,
      const std::vector<mkldnn::memory::desc>& srcs_md) {
Suggested change:
-      const std::vector<mkldnn::memory::desc>& srcs_md) {
+      const std::vector<memory::desc>& srcs_md) {
ok
  // (jczaja) concat oneDNN prim is not having .desc attribute so
  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
      const mkldnn::memory::desc& dst_md, const int concat_axis,
Suggested change:
-      const mkldnn::memory::desc& dst_md, const int concat_axis,
+      const memory::desc& dst_md, const int concat_axis,
ok
  }

  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
      const framework::Tensor& input, int i) {
Suggested change:
-      const framework::Tensor& input, int i) {
+      const Tensor& input, int i) {
ok
template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
Suggested change:
-        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
You have forgotten about this one
@jakpiase Please continue your review
LGTM, I really appreciate your effort to unify our operators
LGTM
Force-pushed from bc89f59 to 4ad092a
LGTM
PR types: Others
PR changes: OPs
Describe:
This PR disables oneDNN primitive caching on the PaddlePaddle side in favour of oneDNN's own internal primitive cache. The concat operator was also refactored to use the Acquire API.