From ad28425c30e4aeb0f3bfc8400daf07db5be291d6 Mon Sep 17 00:00:00 2001
From: SylarTiaNII <15840554235@163.com>
Date: Fri, 2 Feb 2024 20:03:32 +0800
Subject: [PATCH] [Distributed] fix sharding v2 on npu device

---
 .../distributed/fleet/utils/tensor_fusion_helper.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index ba2f4fb2cc016..3c79671e6d9a7 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -200,9 +200,12 @@ def _share_param_buffer(self):
         stop_gradient = self._param.stop_gradient
         self._param.stop_gradient = True
         self._param.flatten_()
-        self._param_buffer[
-            self._index : self._index + self._param._numel()
-        ] = self._param
+        paddle.assign(
+            self._param,
+            self._param_buffer._slice(
+                self._index, self._index + self._param._numel()
+            ),
+        )
         self._param.get_tensor()._set_dims(param_shape)
         self._param.stop_gradient = stop_gradient
         self._param_buffer._slice(
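
Note: a minimal standalone sketch of the copy pattern this patch switches to, for readers unfamiliar with the fusion-buffer code path. The buffer size, parameter size, and offset below are illustrative only; the internal APIs (Tensor._slice, Tensor._numel) are the ones the patched code itself relies on. The idea is that the in-place slice assignment buffer[a:b] = param is replaced by paddle.assign writing into a _slice view that shares storage with the buffer, which per the commit title is what makes sharding v2 work on NPU devices.

    import paddle

    # Illustrative sizes only; the real param_buffer is a fused storage
    # holding many flattened parameters back to back.
    param_buffer = paddle.zeros([16], dtype='float32')
    param = paddle.ones([4], dtype='float32')
    index = 8  # hypothetical offset of this parameter in the buffer

    # Old behavior, which the patch removes:
    #     param_buffer[index : index + param._numel()] = param
    #
    # New behavior: copy via paddle.assign into a _slice view over
    # [index, index + numel) that shares the buffer's storage.
    paddle.assign(param, param_buffer._slice(index, index + param._numel()))

    print(param_buffer.numpy())  # elements 8..11 now hold param's values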