
[oneDNN] Concat refactoring and disabling caching #35002

Merged: 18 commits merged into PaddlePaddle:develop on Aug 24, 2021

Conversation

jczaja (Contributor) commented Aug 18, 2021

PR types

Others

PR changes

OPs

Describe

This PR disables the oneDNN primitive caching done by PaddlePaddle in favour of oneDNN's own caching mechanism. The concat operator is also refactored to the Acquire API.
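
Context for the caching change: oneDNN maintains its own global primitive cache, so recreating an identical primitive is served from that cache and framework-side caching of primitive objects becomes redundant. Below is a minimal, standalone sketch against the plain oneDNN 2.x C++ API (not PaddlePaddle code; the shapes and the capacity value are illustrative):

```cpp
#include <iostream>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // oneDNN's built-in primitive cache is enabled by default; its capacity
  // can be tuned at runtime (1024 here is just an example value).
  dnnl::set_primitive_cache_capacity(1024);

  auto md = dnnl::memory::desc({8, 16, 32, 32}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nchw);
  std::vector<dnnl::memory::desc> srcs_md{md, md};

  // First creation generates the concat kernel and stores it in oneDNN's cache.
  dnnl::concat::primitive_desc pd1(/*concat_dimension=*/1, srcs_md, eng);
  dnnl::concat c1(pd1);

  // An identical primitive created later (e.g. on the next iteration) is a
  // cache hit inside oneDNN, so caching the primitive object again on the
  // framework side would only duplicate work and memory.
  dnnl::concat::primitive_desc pd2(/*concat_dimension=*/1, srcs_md, eng);
  dnnl::concat c2(pd2);

  std::cout << "primitive cache capacity: "
            << dnnl::get_primitive_cache_capacity() << std::endl;
  return 0;
}
```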

paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

jakpiase (Contributor) left a comment:

I really like the initiative to change from the factory approach to the Acquire API :)

class ConcatMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
 public:
  ConcatMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
Contributor:

Suggested change
- ConcatMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
+ ConcatMKLDNNHandler(const framework::ExecutionContext& ctx,

Contributor Author:

ok

}

    memory::data_type dt =
        paddle::framework::ToMKLDNNDataType(inputs[0]->type());
Contributor:

Suggested change
- paddle::framework::ToMKLDNNDataType(inputs[0]->type());
+ framework::ToMKLDNNDataType(inputs[0]->type());

Contributor Author:

ok

    srcs_md.reserve(inputs.size());

    // Create memory descriptors for each of inputs
    for (size_t i = 0; i < inputs.size(); i++) {
Contributor:

Suggested change
- for (size_t i = 0; i < inputs.size(); i++) {
+ for (size_t i = 0; i < inputs.size(); ++i) {

Contributor Author:

ok

    // Create memory descriptors for each of inputs
    for (size_t i = 0; i < inputs.size(); i++) {
      const auto dims =
          paddle::framework::vectorize<int64_t>(inputs[i]->dims());
Contributor:

Suggested change
- paddle::framework::vectorize<int64_t>(inputs[i]->dims());
+ framework::vectorize<int64_t>(inputs[i]->dims());

      srcs_md.emplace_back(memory::desc(dims, dt, inputs[i]->format()));
    }

    auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
Contributor:

Suggested change
- auto dst_dims = paddle::framework::vectorize<int64_t>(output->dims());
+ auto dst_dims = framework::vectorize<int64_t>(output->dims());

  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
      const mkldnn::memory::desc& dst_md, const int concat_axis,
      const std::vector<mkldnn::memory::desc>& srcs_md) {
Contributor:

Suggested change
- const std::vector<mkldnn::memory::desc>& srcs_md) {
+ const std::vector<memory::desc>& srcs_md) {

Contributor Author:

ok

  // (jczaja) concat oneDNN prim is not having .desc attribute so
  // we cannot use base AcquireForwardPrimitiveDescriptor
  void AcquireForwardPrimitiveDescriptor(
      const mkldnn::memory::desc& dst_md, const int concat_axis,
Contributor:

Suggested change
- const mkldnn::memory::desc& dst_md, const int concat_axis,
+ const memory::desc& dst_md, const int concat_axis,

Contributor Author:

ok

}
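
Background for the code comment in this hunk: concat (like sum and reorder) has no operation descriptor in oneDNN, i.e. there is no dnnl::concat::desc; the primitive descriptor is built directly from the concat axis and the source memory descriptors, which is why the base handler's AcquireForwardPrimitiveDescriptor cannot be reused here. A minimal standalone sketch with the plain oneDNN 2.x C++ API (shapes and variable names are illustrative):

```cpp
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream astream(eng);

  auto src_md = dnnl::memory::desc({2, 3, 4, 4}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nchw);
  std::vector<dnnl::memory::desc> srcs_md{src_md, src_md};

  // No dnnl::concat::desc exists: the primitive_desc is created directly
  // from the concat axis and the source memory descriptors.
  dnnl::concat::primitive_desc concat_pd(/*concat_dimension=*/1, srcs_md, eng);
  dnnl::concat concat_prim(concat_pd);

  dnnl::memory src0(src_md, eng), src1(src_md, eng);
  dnnl::memory dst(concat_pd.dst_desc(), eng);

  // Sources are passed as DNNL_ARG_MULTIPLE_SRC + index.
  std::unordered_map<int, dnnl::memory> args{
      {DNNL_ARG_MULTIPLE_SRC + 0, src0},
      {DNNL_ARG_MULTIPLE_SRC + 1, src1},
      {DNNL_ARG_DST, dst}};
  concat_prim.execute(astream, args);
  astream.wait();
  return 0;
}
```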

  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
      const framework::Tensor& input, int i) {
Contributor:

Suggested change
- const framework::Tensor& input, int i) {
+ const Tensor& input, int i) {

Contributor Author:

ok

template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
Contributor:

Suggested change
- ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+ ctx.template device_context<platform::MKLDNNDeviceContext>();

Contributor:

You have forgotten about this one.

jczaja (Contributor, Author) commented Aug 19, 2021:

@jakpiase Please continue your review.

jakpiase previously approved these changes on Aug 19, 2021.

jakpiase (Contributor) left a comment:

LGTM, I really appreciate your effort to unify our operators.

lidanqing-intel (Contributor) left a comment:

LGTM


jczaja merged commit d9c0f09 into PaddlePaddle:develop on Aug 24, 2021.