Skip to content

Commit

Permalink
fix for conv2D training error (#38938)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase authored Jan 17, 2022
1 parent 05c98ec commit 944ea43
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
18 changes: 16 additions & 2 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ class ConvMKLDNNHandlerT
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) {
return weights_mem_p;
} else {
} else if (is_test) {
const K* filter_data = filter->data<K>();
auto weights_tz = framework::vectorize(filter->dims());
platform::GetGroupConvWeightsTz(weights_tz, groups);
Expand All @@ -626,6 +626,19 @@ class ConvMKLDNNHandlerT
user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
scale_data, mask);
} else {
const T* filter_data = filter->data<T>();
auto weights_tz = framework::vectorize(filter->dims());
platform::GetGroupConvWeightsTz(weights_tz, groups);

auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(),
GetWeightsFormat(filter->format(), groups, is_conv3d));

return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<T>(filter_data), "@weights_mem_p", is_test, {},
scale_data, mask);
}
}

Expand Down Expand Up @@ -1027,7 +1040,8 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, MKLDNN,
::paddle::platform::CPUPlace, FP32,
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/platform/mkldnn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,14 @@ class MKLDNNHandlerT {
if (bwd_pd_ == nullptr) {
return false;
} else {
if (std::is_same<TBackward_params, mkldnn_dummy_primitive>::value ==
false) {
const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_bw_w_pd));
}

// When BWD is cached then still we need to Get FWD PD
const std::string key_fpd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def setUp(self):
self.init_fuse_residual()
self.init_data_type()
self.init_force_fp32_output()
self.init_infer_or_train()

self.conv2d_param = {
'stride': self.stride,
Expand Down Expand Up @@ -83,6 +84,9 @@ def setUp(self):
if self.input_type is not np.float32:
self.input = convert_float_to_uint16(self.input)

if self.weight_type is not np.float32:
self.filter = convert_float_to_uint16(self.filter)

self.inputs = {
'Input': self.input,
'Filter': OpTest.np_dtype_to_fluid_dtype(
Expand All @@ -105,6 +109,8 @@ def setUp(self):
'fuse_residual_connection': self.fuse_residual
}

self.init_additional_attrs()

def test_check_output(self):
self.check_output_with_place(core.CPUPlace())

Expand Down Expand Up @@ -141,6 +147,12 @@ def init_fuse_relu(self):
def init_fuse_residual(self):
self.fuse_residual = True

def init_infer_or_train(self):
self.weight_type = np.float32

def init_additional_attrs(self):
self.attrs['is_test'] = True


@OpTestTool.skip_if_not_cpu_bf16()
class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
Expand All @@ -150,6 +162,12 @@ def init_fuse_relu(self):
def init_fuse_residual(self):
self.fuse_residual = None

def init_additional_attrs(self):
self.attrs['is_test'] = False

def init_infer_or_train(self):
self.weight_type = np.uint16

def test_check_grad(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
Expand Down

0 comments on commit 944ea43

Please sign in to comment.