[oneDNN] candidate fix to #34554 #35884

Merged (5 commits) on Sep 24, 2021
21 changes: 9 additions & 12 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -706,7 +706,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::CreateKey(dev_ctx, src_tz, src_dt,
ctx.InputName("Input") + ctx.InputName("Filter"));

const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
@@ -721,6 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// are merged/unified, this will disappear
auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);

const std::string key_conv_pd = key_tid + "@conv_pd";
auto prim_key = key_tid + "@conv_p";
auto dst_key = key_tid + "@dst_mem_p";
auto src_key = key_tid + "@src_mem_p";
@@ -731,12 +731,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_reorder_key = key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p";

conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

if (conv_p == nullptr || !is_test) {
if (conv_pd == nullptr || !is_test) {
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
@@ -946,7 +947,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}

// create convolution op primitive
auto scale_bias_key = key + "@scale_bias";
conv_p = handler->AcquireConvolution();
if (bias) {
const K* bias_data = bias->data<K>();
@@ -1000,13 +1000,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.GetBlob(weights_key));
dst_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
if (conv_pd) {
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));
}
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));

if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
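Taken together, this hunk derives key_conv_pd from the thread-extended key_tid and gates primitive creation on the cached primitive descriptor (conv_pd) rather than on the primitive (conv_p). The sketch below illustrates that caching pattern in isolation; BlobCache, MakeThreadKey and ConvPd are hypothetical stand-ins for the device context's blob map, ExtendKeyWithThreadInfoIfNeeded and mkldnn::convolution_forward::primitive_desc, so this is a minimal model of the idea rather than Paddle's actual code.

```cpp
// Hedged sketch of the cache-key pattern this diff moves to: the primitive
// descriptor is stored under a per-thread key, so two threads running the
// same op no longer pick up (or overwrite) each other's cached conv_pd.
// BlobCache, MakeThreadKey and ConvPd are illustrative stand-ins only.
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>

struct ConvPd {};  // stand-in for mkldnn::convolution_forward::primitive_desc

class BlobCache {  // stand-in for MKLDNNDeviceContext's blob map
 public:
  std::shared_ptr<void> Get(const std::string& key) {
    std::lock_guard<std::mutex> g(mu_);
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void Set(const std::string& key, std::shared_ptr<void> blob) {
    std::lock_guard<std::mutex> g(mu_);
    blobs_[key] = std::move(blob);
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

// Mirrors the role of ExtendKeyWithThreadInfoIfNeeded: append the thread id.
std::string MakeThreadKey(const std::string& base_key) {
  std::ostringstream os;
  os << base_key << "-t:" << std::this_thread::get_id();
  return os.str();
}

std::shared_ptr<ConvPd> AcquireConvPd(BlobCache& cache, const std::string& base_key) {
  // After the change: "@conv_pd" hangs off the thread-extended key, and
  // creation is gated on the cached descriptor being absent.
  const std::string key_conv_pd = MakeThreadKey(base_key) + "@conv_pd";
  auto pd = std::static_pointer_cast<ConvPd>(cache.Get(key_conv_pd));
  if (pd == nullptr) {
    pd = std::make_shared<ConvPd>();  // build desc/attr and create the pd here
    cache.Set(key_conv_pd, pd);       // cache it for later iterations
  }
  return pd;
}

int main() {
  BlobCache cache;
  auto first = AcquireConvPd(cache, "conv2d");
  auto second = AcquireConvPd(cache, "conv2d");
  std::cout << std::boolalpha << (first == second) << "\n";  // true: cache hit
}
```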
56 changes: 23 additions & 33 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -603,7 +603,6 @@ class MKLDNNHandler {
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
@@ -789,7 +788,6 @@ class MKLDNNHandler {
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_common_;
std::string key_;
};

@@ -1371,42 +1369,34 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
// Conv PD has to be passed to Grad op that
// may be executed by a different thread, hence
// for that one we use key that does not contain TID
const std::string key_conv_pd = key_common_ + "@conv_pd";
const std::string key_conv_pd = key_ + "@conv_pd";

conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));

if (conv_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);

conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd_ == nullptr) {
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims dilations_dims = dilations;
auto mkldnn_paddings = ToMkldnnPadding(paddings);

auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, *bias, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1])
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);

mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, output_shift_scale, sum_scale);

conv_pd_.reset(new typename forward_t::primitive_desc(
conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
}
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims dilations_dims = dilations;
auto mkldnn_paddings = ToMkldnnPadding(paddings);

auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, *bias, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1])
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);

mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, output_shift_scale, sum_scale);

conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
}

return conv_pd_;
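The block removed in this hunk was double-checked locking: a static mutex plus a second GetBlob guarded creation of a conv_pd that was shared across threads via key_common_. With the descriptor now cached only under the thread-extended key_, the PR's apparent assumption is that each thread touches just its own entry, so a plain lookup-then-create suffices. A minimal sketch of that simplified acquire path follows; Cache, Pd and BuildPd are hypothetical placeholders, not Paddle's or oneDNN's API.

```cpp
// Sketch of the simplified handler path after this diff: with a per-thread
// key there is no cross-thread publisher of key_conv_pd, so the static mutex
// and the second GetBlob (double-checked locking) can be dropped.
#include <map>
#include <memory>
#include <string>

struct Pd {};  // placeholder for the convolution primitive descriptor
using Cache = std::map<std::string, std::shared_ptr<Pd>>;

std::shared_ptr<Pd> BuildPd() {  // placeholder for desc + attr + engine setup
  return std::make_shared<Pd>();
}

std::shared_ptr<Pd> AcquirePd(Cache& cache, const std::string& key_conv_pd) {
  auto& slot = cache[key_conv_pd];
  if (slot == nullptr) {  // single check is enough: the key is thread-local
    slot = BuildPd();     // create the primitive descriptor once
  }                       // and keep it cached for the backward pass
  return slot;
}

int main() {
  Cache cache;
  auto a = AcquirePd(cache, "conv2d-t:1@conv_pd");
  auto b = AcquirePd(cache, "conv2d-t:1@conv_pd");
  return a == b ? 0 : 1;  // same cached pd on the second call
}
```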