From e4786493e5366c8103b63d1a5d67e235fb1666d2 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Mon, 15 May 2023 10:17:03 +0000 Subject: [PATCH] [CustomDevice] fix BatchNorm --- python/paddle/nn/layer/norm.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 9facd8e917273..eef98d96632fa 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -975,6 +975,29 @@ def __init__( ) self._variance.stop_gradient = True + # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op + if ( + _global_flags()['FLAGS_npu_storage_format'] + and 'npu' in get_all_custom_device_type() + ): + with no_grad(): + weight_trans = _C_ops.npu_identity( + self.weight, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + bias_trans = _C_ops.npu_identity( + self.bias, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + mean_trans = _C_ops.npu_identity( + self._mean, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + var_trans = _C_ops.npu_identity( + self._variance, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + weight_trans._share_underline_tensor_to(self.weight) + bias_trans._share_underline_tensor_to(self.bias) + mean_trans._share_underline_tensor_to(self._mean) + var_trans._share_underline_tensor_to(self._variance) + self._in_place = in_place self._data_layout = data_layout self._momentum = momentum