From 834649f5a9543beeb48226d3d0901eb3537fbe5b Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Wed, 31 Jul 2024 11:05:13 +0800
Subject: [PATCH] Cp 28 for xpu (#8812)

* xpu use allgather (#8697)

* xpu use allgather

* xpu use allgather

* fix xpu gather for unified ckpt (#8710)

* bug fix (#8730)

* [XPU] use allgather and fp32 multinomial for XPU (#8787)

---------

Co-authored-by: houj04 <35131887+houj04@users.noreply.github.com>
---
 paddlenlp/generation/utils.py                   |  2 ++
 paddlenlp/peft/lora/lora_model.py               |  8 ++++++--
 paddlenlp/trainer/plugins/unified_checkpoint.py | 13 ++++++++++---
 paddlenlp/transformers/conversion_utils.py      |  8 ++++++--
 paddlenlp/utils/distributed.py                  |  2 +-
 5 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py
index 6f5061318ad2..2356c116c74e 100644
--- a/paddlenlp/generation/utils.py
+++ b/paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@ def sample(
             probs = TopKProcess(probs, top_k, min_tokens_to_keep)
         if top_p is not None and top_p < 1.0:
             probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+        if paddle.device.is_compiled_with_xpu():
+            probs = paddle.cast(probs, "float32")
 
         # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
         next_tokens = paddle.multinomial(probs)
diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index bf69760ae69a..51b8b7ae36b4 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -56,9 +56,10 @@ class RowSequenceParallelLinear:
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 from .lora_layers import (
     ColumnParallelLoRALinear,
@@ -301,7 +302,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
         for key in trainable_state_dict:
             tensor = trainable_state_dict[key]
             if key in trainable_name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = trainable_name_action_mappings[key]
                 if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                     ret = paddle.to_tensor(ret)
diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index a8e1199a59b8..db928cab4f8a 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1747,7 +1748,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1784,7 +1788,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py
index d4d098bb31cb..f457bf28e856 100644
--- a/paddlenlp/transformers/conversion_utils.py
+++ b/paddlenlp/transformers/conversion_utils.py
@@ -37,7 +37,7 @@
 from paddle import Tensor
 from paddle.nn import Layer
 
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
 from paddlenlp.utils.import_utils import (
     is_package_available,
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.serialization import load_torch
+from paddlenlp.utils.tools import get_env_device
 
 if TYPE_CHECKING:
     from paddlenlp.transformers import PretrainedConfig, PretrainedModel
@@ -1269,7 +1270,10 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:
         for key in state_dict.keys():
             tensor = state_dict[key]
             if key in name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = name_action_mappings.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
diff --git a/paddlenlp/utils/distributed.py b/paddlenlp/utils/distributed.py
index 4471a06b732d..9ccc7fbd63a8 100644
--- a/paddlenlp/utils/distributed.py
+++ b/paddlenlp/utils/distributed.py
@@ -214,7 +214,7 @@ def distributed_allgather(tensor: Any, group=None, offload=False):
             x.reshape_(origin_shape)
 
     else:
-        distributed.all_gather(output_tensors, tensor)
+        distributed.all_gather(output_tensors, tensor, group=group)
 
     return output_tensors
 
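Note, not part of the patch: the generation/utils.py hunk upcasts the sampling probabilities to float32 on XPU before paddle.multinomial; per the pre-existing comment, fp16/bf16 inputs are already handled elsewhere (Paddle issue 51852), so the cast is XPU-only. A minimal sketch of the resulting sampling step, where the helper name sample_next_tokens is hypothetical:

    import paddle

    def sample_next_tokens(probs):
        # probs: [batch_size, vocab_size] probabilities that have already been
        # through top-k/top-p filtering, possibly in fp16/bf16. On XPU the
        # patch upcasts to fp32 before multinomial; other devices sample
        # directly, exactly as in the hunk above.
        if paddle.device.is_compiled_with_xpu():
            probs = paddle.cast(probs, "float32")
        return paddle.multinomial(probs)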
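Note, not part of the patch: the same four-line device dispatch is repeated at every merge site in lora_model.py, unified_checkpoint.py, and conversion_utils.py. Below is a minimal sketch of a helper that would fold it into one place; gather_for_merge is a hypothetical name, while get_env_device, distributed_allgather, and distributed_gather are the utilities the patch itself imports:

    from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
    from paddlenlp.utils.tools import get_env_device

    def gather_for_merge(tensor, dst, group, offload=False):
        # On XPU the patch replaces the concat-gather to a single rank with
        # all_gather, so every rank receives the full shard list; elsewhere
        # only rank `dst` receives it. Call sites keep the result only when
        # they are the destination (`is_dst`), so both paths merge the same
        # tensor and the extra copies on non-dst XPU ranks are tolerated.
        if get_env_device() == "xpu":
            return distributed_allgather(tensor, group=group, offload=offload)
        return distributed_gather(tensor, dst=dst, group=group, offload=offload)

Note that distributed_allgather takes no dst argument: all_gather is symmetric across ranks, which is presumably why the series prefers it on XPU.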
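Note, not part of the patch: the paddlenlp/utils/distributed.py hunk fixes a latent bug in the non-offload path of distributed_allgather. distributed.all_gather was called without forwarding the group argument, so the collective ran over the default global group even when the caller passed a tensor-parallel group, as every call site above does. Forwarding group=group keeps the gather on the caller's communicator.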