From 944ea4368428b73d68172603cf041e997e08ab25 Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Mon, 17 Jan 2022 10:27:17 +0100
Subject: [PATCH] fix for conv2D training error (#38938)

---
 .../fluid/operators/mkldnn/conv_mkldnn_op.cc  | 18 ++++++++++++++++--
 paddle/fluid/platform/mkldnn_reuse.h          |  8 ++++++++
 .../mkldnn/test_conv2d_bf16_mkldnn_op.py      | 18 ++++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 0526ae52b3903..44289015bc7c4 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -613,7 +613,7 @@ class ConvMKLDNNHandlerT
     auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
     if (is_test && weights_mem_p) {
       return weights_mem_p;
-    } else {
+    } else if (is_test) {
       const K* filter_data = filter->data<K>();
       auto weights_tz = framework::vectorize(filter->dims());
       platform::GetGroupConvWeightsTz(weights_tz, groups);
@@ -626,6 +626,19 @@ class ConvMKLDNNHandlerT
           user_src_md, this->fwd_pd_->weights_desc(),
           platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
           scale_data, mask);
+    } else {
+      const T* filter_data = filter->data<T>();
+      auto weights_tz = framework::vectorize(filter->dims());
+      platform::GetGroupConvWeightsTz(weights_tz, groups);
+
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          GetWeightsFormat(filter->format(), groups, is_conv3d));
+
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          platform::to_void_cast<T>(filter_data), "@weights_mem_p", is_test, {},
+          scale_data, mask);
     }
   }
 
@@ -1027,7 +1040,8 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
     conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
     ops::kConvMKLDNNFP32,
-    ops::ConvMKLDNNGradOpKernel);
+    ops::ConvMKLDNNGradOpKernel);
 
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, FP32,
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index ef216e48416f9..9aadd36c2e8ac 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -377,6 +377,14 @@ class MKLDNNHandlerT {
     if (bwd_pd_ == nullptr) {
       return false;
     } else {
+      if (std::is_same::value ==
+          false) {
+        const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
+        bwd_w_pd_ =
+            std::static_pointer_cast(
+                dev_ctx_.GetBlob(key_bw_w_pd));
+      }
+
       // When BWD is cached then still we need to Get FWD PD
       const std::string key_fpd = key_ + "@fwd_pd";
       fwd_pd_ = std::static_pointer_cast(
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
index 4c753da0512f8..702d26b073b6b 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -50,6 +50,7 @@ def setUp(self):
         self.init_fuse_residual()
         self.init_data_type()
         self.init_force_fp32_output()
+        self.init_infer_or_train()
 
         self.conv2d_param = {
             'stride': self.stride,
@@ -83,6 +84,9 @@ def setUp(self):
         if self.input_type is not np.float32:
             self.input = convert_float_to_uint16(self.input)
 
+        if self.weight_type is not np.float32:
+            self.filter = convert_float_to_uint16(self.filter)
+
         self.inputs = {
             'Input': self.input,
             'Filter': OpTest.np_dtype_to_fluid_dtype(
@@ -105,6 +109,8 @@ def setUp(self):
             'fuse_residual_connection': self.fuse_residual
         }
 
+        self.init_additional_attrs()
+
     def test_check_output(self):
         self.check_output_with_place(core.CPUPlace())
 
@@ -141,6 +147,12 @@ def init_fuse_relu(self):
     def init_fuse_residual(self):
         self.fuse_residual = True
 
+    def init_infer_or_train(self):
+        self.weight_type = np.float32
+
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = True
+
 
 @OpTestTool.skip_if_not_cpu_bf16()
 class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
@@ -150,6 +162,12 @@ def init_fuse_relu(self):
     def init_fuse_residual(self):
         self.fuse_residual = None
 
+    def init_additional_attrs(self):
+        self.attrs['is_test'] = False
+
+    def init_infer_or_train(self):
+        self.weight_type = np.uint16
+
     def test_check_grad(self):
         dout = self.conv_output_float
         x = self.inputs_fp32['Input']
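
Note on the test change: when init_infer_or_train() selects weight_type = np.uint16 (the training case), the filter is fed to the operator as bfloat16 values stored in a uint16 buffer via convert_float_to_uint16. The sketch below is only a rough, self-contained numpy approximation of that conversion (plain truncation of float32 to its upper 16 bits); it is an assumption about the helper's behavior, not Paddle's actual utility, and the names float32_to_bf16_bits / bf16_bits_to_float32 are invented for illustration.

import numpy as np

def float32_to_bf16_bits(arr):
    # Reinterpret float32 as uint32 and keep the upper 16 bits, which is the
    # bfloat16 bit pattern (round-toward-zero; NaN/rounding subtleties ignored).
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Zero-fill the low 16 bits to recover a float32 approximation.
    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

if __name__ == "__main__":
    filt = np.random.random((16, 3, 3, 3)).astype(np.float32)
    bf16_filt = float32_to_bf16_bits(filt)      # uint16 data, as the test feeds 'Filter'
    roundtrip = bf16_bits_to_float32(bf16_filt)
    print(np.max(np.abs(filt - roundtrip)))     # small truncation error from the 8-bit mantissa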