Commit

Cp 28 for xpu (#8812)
* xpu use allgather (#8697)

* xpu use allgather

* xpu use allgather

* fix xpu gather for unified ckpt (#8710)

* bug fix (#8730)

* [XPU] use allgather and fp32 multinomial for XPU (#8787)

---------

Co-authored-by: houj04 <35131887+houj04@users.noreply.github.com>
FeixLiu and houj04 authored Jul 31, 2024
1 parent a554c48 commit 834649f
Showing 5 changed files with 25 additions and 8 deletions.
2 changes: 2 additions & 0 deletions paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@ def sample(
                 probs = TopKProcess(probs, top_k, min_tokens_to_keep)
             if top_p is not None and top_p < 1.0:
                 probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+            if paddle.device.is_compiled_with_xpu():
+                probs = paddle.cast(probs, "float32")
 
             # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
             next_tokens = paddle.multinomial(probs)
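
The hunk above guards paddle.multinomial with an explicit fp32 cast on XPU builds, matching the "fp32 multinomial for XPU" item in the commit message. A minimal sketch of the same pattern in isolation, assuming only a Paddle install; the random logits are placeholders, not part of the change:

import paddle
import paddle.nn.functional as F

# Placeholder logits; during generation these come from the decoder, often in fp16/bf16.
logits = paddle.randn([2, 8])
probs = F.softmax(logits, axis=-1)

if paddle.device.is_compiled_with_xpu():
    # On XPU, sample in fp32: upcast the (possibly low-precision) probabilities first.
    probs = paddle.cast(probs, "float32")

next_tokens = paddle.multinomial(probs)  # shape [2, 1], one sampled token id per row
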
8 changes: 6 additions & 2 deletions paddlenlp/peft/lora/lora_model.py
@@ -56,9 +56,10 @@ class RowSequenceParallelLinear:
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 from .lora_layers import (
     ColumnParallelLoRALinear,
@@ -301,7 +302,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
         for key in trainable_state_dict:
             tensor = trainable_state_dict[key]
             if key in trainable_name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = trainable_name_action_mappings[key]
                 if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                     ret = paddle.to_tensor(ret)
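
The same get_env_device() == "xpu" branch recurs in the two files below. The difference it switches on: distributed_gather collects the tensor-parallel shards onto a single destination rank, while distributed_allgather returns the shard list on every rank, which is why the dst argument is dropped where it was used. A hypothetical wrapper, not part of this commit, showing the branch in isolation:

# Hypothetical helper (illustration only): gather tensor-parallel shards either to one
# rank (gather) or to all ranks (allgather, the path taken on XPU).
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.tools import get_env_device


def gather_tp_shards(tensor, group, dst=0, offload=False):
    if get_env_device() == "xpu":
        # Every rank receives the full list of shards; no destination rank is needed.
        return distributed_allgather(tensor, group=group, offload=offload)
    # With gather, the merge code in this diff only consumes the result where is_dst is true.
    return distributed_gather(tensor, dst=dst, group=group, offload=offload)

In the call sites the downstream logic is unchanged: the merge action is still applied only on the destination rank (tensor = action(ret) if is_dst else None), so the extra copies that allgather produces on other ranks are simply discarded.
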
13 changes: 10 additions & 3 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1747,7 +1748,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1784,7 +1788,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
8 changes: 6 additions & 2 deletions paddlenlp/transformers/conversion_utils.py
@@ -37,7 +37,7 @@
 from paddle import Tensor
 from paddle.nn import Layer
 
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
 from paddlenlp.utils.import_utils import (
     is_package_available,
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.serialization import load_torch
+from paddlenlp.utils.tools import get_env_device
 
 if TYPE_CHECKING:
     from paddlenlp.transformers import PretrainedConfig, PretrainedModel
@@ -1269,7 +1270,10 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:
         for key in state_dict.keys():
             tensor = state_dict[key]
             if key in name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = name_action_mappings.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
2 changes: 1 addition & 1 deletion paddlenlp/utils/distributed.py
@@ -214,7 +214,7 @@ def distributed_allgather(tensor: Any, group=None, offload=False):
             x.reshape_(origin_shape)
 
     else:
-        distributed.all_gather(output_tensors, tensor)
+        distributed.all_gather(output_tensors, tensor, group=group)
 
     return output_tensors

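The last hunk forwards the caller's group to paddle.distributed.all_gather; without it, the collective runs over the default (global) communication group instead of the tensor-parallel subgroup passed in by the call sites above. A standalone sketch of the corrected call, assuming it is launched with paddle.distributed.launch on two or more ranks; the subgroup here is a stand-in built over all ranks, where real callers would pass a tensor-parallel group:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
sub_group = dist.new_group(ranks=list(range(dist.get_world_size())))  # stand-in subgroup

shard = paddle.full([2], float(dist.get_rank()))  # each rank contributes its own shard
gathered = []
dist.all_gather(gathered, shard, group=sub_group)  # one gathered tensor per rank of sub_group
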
