Commit
- More fixes to #34554
jczaja committed Sep 21, 2021
1 parent 57934f0 commit b10e795
Showing 1 changed file with 23 additions and 33 deletions.
56 changes: 23 additions & 33 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -603,7 +603,6 @@ class MKLDNNHandler {
                 const std::string& base_key)
       : dev_ctx_(dev_ctx),
         engine_(engine),
-        key_common_(base_key),
         key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {
     platform::MKLDNNDeviceContext::tls().log_lib_version();
   }
@@ -789,7 +788,6 @@ class MKLDNNHandler {
  protected:
   const MKLDNNDeviceContext& dev_ctx_;
   mkldnn::engine engine_;
-  std::string key_common_;
   std::string key_;
 };

@@ -1371,42 +1369,34 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     // Conv PD has to be passed to Grad op that
     // may be executed by a different thread, hence
     // for that one we use a key that does not contain TID
-    const std::string key_conv_pd = key_common_ + "@conv_pd";
+    const std::string key_conv_pd = key_ + "@conv_pd";

     conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
         dev_ctx_.GetBlob(key_conv_pd));
 
     if (conv_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-
-      conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
-          dev_ctx_.GetBlob(key_conv_pd));
-      if (conv_pd_ == nullptr) {
-        mkldnn::memory::dims stride_dims = strides;
-        mkldnn::memory::dims dilations_dims = dilations;
-        auto mkldnn_paddings = ToMkldnnPadding(paddings);
-
-        auto conv_desc =
-            bias ? typename forward_t::desc(
-                       fwd_prop_kind, convolutional_algorithm<forward_t>::T,
-                       src, weights, *bias, dst, stride_dims, dilations_dims,
-                       mkldnn_paddings[0], mkldnn_paddings[1])
-                 : typename forward_t::desc(
-                       fwd_prop_kind, convolutional_algorithm<forward_t>::T,
-                       src, weights, dst, stride_dims, dilations_dims,
-                       mkldnn_paddings[0], mkldnn_paddings[1]);
-
-        mkldnn::primitive_attr conv_attr =
-            CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
-                          fuse_residual_conn, output_shift_scale, sum_scale);
-
-        conv_pd_.reset(new typename forward_t::primitive_desc(
-            conv_desc, conv_attr, engine));
-        // Save conv_pd/src_memory/weights_memory for backward pass
-        dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
-      }
+      mkldnn::memory::dims stride_dims = strides;
+      mkldnn::memory::dims dilations_dims = dilations;
+      auto mkldnn_paddings = ToMkldnnPadding(paddings);
+
+      auto conv_desc =
+          bias ? typename forward_t::desc(
+                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
+                     weights, *bias, dst, stride_dims, dilations_dims,
+                     mkldnn_paddings[0], mkldnn_paddings[1])
+               : typename forward_t::desc(
+                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
+                     weights, dst, stride_dims, dilations_dims,
+                     mkldnn_paddings[0], mkldnn_paddings[1]);
+
+      mkldnn::primitive_attr conv_attr =
+          CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
+                        fuse_residual_conn, output_shift_scale, sum_scale);
+
+      conv_pd_.reset(
+          new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
+      // Save conv_pd/src_memory/weights_memory for backward pass
+      dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
     }

     return conv_pd_;
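The first two hunks delete the key_common_ member, which held the raw base_key with no thread information; after this commit only key_ remains, built through platform::ExtendKeyWithThreadInfoIfNeeded. As a rough sketch of that idea (the helper name ExtendKeyWithThreadInfo, the per_thread_cache_enabled flag, and the sample key below are illustrative assumptions, not Paddle's actual signature), a cache key becomes thread-local simply by appending the calling thread's id:

#include <iostream>
#include <sstream>
#include <string>
#include <thread>

// Illustrative stand-in for platform::ExtendKeyWithThreadInfoIfNeeded: append
// the caller's thread id so each thread resolves to its own cached blob.
std::string ExtendKeyWithThreadInfo(const std::string& base_key,
                                    bool per_thread_cache_enabled) {
  if (!per_thread_cache_enabled) {
    return base_key;  // single cache entry shared by all threads
  }
  std::ostringstream key;
  key << base_key << "-tid:" << std::this_thread::get_id();
  return key.str();
}

int main() {
  // Prints something like "conv2d_k3s1@conv_pd-tid:140245..." (id is platform-specific).
  std::cout << ExtendKeyWithThreadInfo("conv2d_k3s1@conv_pd", true) << std::endl;
  return 0;
}

With a key of that shape there is no longer a reason to keep the thread-agnostic base_key around as key_common_.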

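The third hunk removes the function-local acquire_barrier mutex and the double-checked lookup that guarded creation of the shared conv_pd_ blob, leaving a plain check-create-store sequence keyed by key_. A minimal stand-alone sketch of the two schemes, assuming a simple mutex-protected blob map (BlobCache, ConvPrimitiveDesc, AcquireShared and AcquirePerThread are hypothetical stand-ins, not the MKLDNNDeviceContext API):

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>

struct ConvPrimitiveDesc { int id = 0; };  // stand-in for forward_t::primitive_desc

// Stand-in for the device context blob storage: a mutex-protected map.
class BlobCache {
 public:
  std::shared_ptr<ConvPrimitiveDesc> Get(const std::string& key) {
    std::lock_guard<std::mutex> guard(mu_);
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void Set(const std::string& key, std::shared_ptr<ConvPrimitiveDesc> pd) {
    std::lock_guard<std::mutex> guard(mu_);
    blobs_[key] = pd;
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::shared_ptr<ConvPrimitiveDesc>> blobs_;
};

// Old scheme: one key shared by every thread, so descriptor creation is
// serialized with double-checked locking on a function-local mutex.
std::shared_ptr<ConvPrimitiveDesc> AcquireShared(BlobCache* cache) {
  const std::string key = "conv2d@conv_pd";
  auto pd = cache->Get(key);
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block(acquire_barrier);
    pd = cache->Get(key);  // re-check after taking the lock
    if (pd == nullptr) {
      pd = std::make_shared<ConvPrimitiveDesc>();
      cache->Set(key, pd);
    }
  }
  return pd;
}

// New scheme: the key already carries thread information, so keys from
// different threads cannot collide and no extra barrier is needed.
std::shared_ptr<ConvPrimitiveDesc> AcquirePerThread(BlobCache* cache) {
  std::ostringstream key;
  key << "conv2d@conv_pd-tid:" << std::this_thread::get_id();
  auto pd = cache->Get(key.str());
  if (pd == nullptr) {
    pd = std::make_shared<ConvPrimitiveDesc>();
    cache->Set(key.str(), pd);
  }
  return pd;
}

int main() {
  BlobCache cache;
  std::thread t1([&] { AcquirePerThread(&cache); });
  std::thread t2([&] { AcquirePerThread(&cache); });
  t1.join();
  t2.join();
  std::cout << "shared pd id: " << AcquireShared(&cache)->id << std::endl;
  return 0;
}

Under the shared key the lock is what stops two threads from building the same descriptor twice; once the key is per-thread the entries cannot collide, which is what allows the barrier to be deleted.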