diff --git a/dynamic/utils/hybrid_optimizer.py b/dynamic/utils/hybrid_optimizer.py
index aa4eb41..65c7ee9 100644
--- a/dynamic/utils/hybrid_optimizer.py
+++ b/dynamic/utils/hybrid_optimizer.py
@@ -85,12 +85,20 @@ def _append_optimize_op(self, block, param_and_grad):
         if getattr(param_and_grad[0], 'is_sparse_grad', None):
             index = getattr(param_and_grad[0], 'index', None)
             axis = getattr(param_and_grad[0], 'axis', None)
-            _, _ = paddle._C_ops.sparse_momentum(
-                param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
-                param_and_grad[0], velocity_acc, 'mu', self._momentum,
-                'use_nesterov', self._use_nesterov, 'regularization_method',
-                self._regularization_method, 'regularization_coeff',
-                self._regularization_coeff, 'axis', axis)
+            try:
+                _, _ = paddle._C_ops.sparse_momentum(
+                    param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
+                    param_and_grad[0], velocity_acc, 'mu', self._momentum,
+                    'use_nesterov', self._use_nesterov, 'regularization_method',
+                    self._regularization_method, 'regularization_coeff',
+                    self._regularization_coeff, 'axis', axis)
+            except:
+                _, _, _ = paddle._C_ops.sparse_momentum(
+                    param_and_grad[0], param_and_grad[1], velocity_acc, index, lr, master_weight,
+                    param_and_grad[0], velocity_acc, master_weight, 'mu', self._momentum,
+                    'use_nesterov', self._use_nesterov, 'regularization_method',
+                    self._regularization_method, 'regularization_coeff',
+                    self._regularization_coeff, 'axis', axis, 'multi_precision', find_master)
         else:
             _, _, _ = paddle._C_ops.momentum(
                 param_and_grad[0], param_and_grad[1], velocity_acc, lr,
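
For readers skimming the hunk: the change wraps the old six-argument sparse_momentum call in a try block and retries with a newer signature that threads a master weight through the op (master_weight in the inputs and outputs, plus the 'multi_precision'/find_master attributes). Below is a minimal, self-contained sketch of that probe-and-fallback pattern. Everything in it is hypothetical: _sparse_momentum stands in for paddle._C_ops.sparse_momentum as built with the newer signature, the update math is a plain dense momentum step (the real op does a sparse row update selected by index along axis, with nesterov and regularization options), and apply_sparse_momentum is an invented wrapper, not part of this file.

# Hypothetical stand-in for a newer build of paddle._C_ops.sparse_momentum:
# it requires master_weight, so a legacy-style call raises TypeError --
# the exact situation the diff's except branch is written to catch.
def _sparse_momentum(param, grad, velocity, lr, master_weight, mu):
    velocity_out = mu * velocity + grad    # plain momentum accumulation
    param_out = param - lr * velocity_out  # simplified dense update
    return param_out, velocity_out, master_weight


def apply_sparse_momentum(param, grad, velocity, lr, mu, master_weight=None):
    try:
        # Probe the legacy calling convention first (no master weight).
        param_out, velocity_out = _sparse_momentum(param, grad, velocity, lr, mu)
    except TypeError:
        # Newer builds reject the short signature; retry with master_weight.
        # TypeError is what a Python-level arity mismatch raises; the diff's
        # bare `except:` is broader and would also swallow unrelated errors.
        param_out, velocity_out, _ = _sparse_momentum(
            param, grad, velocity, lr, master_weight, mu)
    return param_out, velocity_out


if __name__ == "__main__":
    # With the "newer" stub above, the legacy probe fails and the fallback
    # path runs: velocity = 0.9*0 + 0.5 = 0.5, param = 1.0 - 0.1*0.5 = 0.95.
    print(apply_sparse_momentum(
        param=1.0, grad=0.5, velocity=0.0, lr=0.1, mu=0.9, master_weight=1.0))

Probing by call rather than by version string keeps the optimizer working on both sides of the signature change; the cost, as the sketch's comments note, is that the diff's bare except: can mask genuine op failures.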