[XPU] llama add xpu support #8282
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.distributed.fleet.meta_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+    ColumnSequenceParallelLinear,
+    RowSequenceParallelLinear,
+)
+from paddle.nn import Linear
+
+from paddlenlp.utils.tools import get_env_device
+
+from .mc2_parallel_linear import MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear
+
+if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
+    ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear  # noqa: F811
+    RowSequenceParallelLinear = MC2RowSeqParallelLinear  # noqa: F811
+elif get_env_device() == "xpu":
+    try:
+        from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear
+        from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401
+        from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear
+        from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+            XPUColumnSequenceParallelLinear,
+            XPURowSequenceParallelLinear,
+        )
+
+        Linear = XPULinear  # noqa: F811
+        ColumnParallelLinear = XPUColumnParallelLinear  # noqa: F811
+        RowParallelLinear = XPURowParallelLinear  # noqa: F811
+        ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear  # noqa: F811
+        RowSequenceParallelLinear = XPURowSequenceParallelLinear  # noqa: F811
+    except ImportError:
+        # It's OK, just use paddle's Linear layers
+        pass
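The new module above centralizes how Linear layers are chosen: MC2 sequence-parallel layers when available, paddle_xpu layers on XPU, and stock Paddle layers otherwise. Below is a hedged sketch of how a model file is expected to consume it; the absolute module path is an assumption inferred from the relative `from .. import linear_utils` imports later in this diff.

```python
# Usage sketch (not part of the PR); paddlenlp.transformers.linear_utils is
# an assumed path inferred from the relative imports in the llama hunks below.
from paddlenlp.transformers import linear_utils
from paddlenlp.transformers.linear_utils import Linear

# Resolves to paddle.nn.Linear, or to paddle_xpu's Linear on XPU machines
# where paddle_xpu is installed; callers never branch on the device.
down_proj = Linear(4096, 1024, bias_attr=False)

# Parallel variants are read off the module object so the MC2 / XPU / default
# choice made at import time is reused consistently everywhere.
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear
```

The remaining hunks below update the Llama modeling code to route its projections through these wrappers.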
@@ -62,10 +62,6 @@
     init_name_mappings,
 )
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
-from paddlenlp.transformers.mc2_parallel_linear import (
-    MC2ColumnSeqParallelLinear,
-    MC2RowSeqParallelLinear,
-)
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -74,6 +70,8 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.tools import get_env_device

+from .. import linear_utils
+from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
@@ -410,6 +408,15 @@
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                try:
+                    import paddle_xpu_nn  # noqa: F821
+
+                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+                except ImportError:
+                    raise NotImplementedError(
+                        f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
+                    )
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

         if paddle.in_dynamic_mode():
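For reference, rms_norm_npu, xpu_rms_norm, and rms_norm_fused all compute standard RMSNorm; a hedged, unfused sketch of that math (variable names illustrative, float32 accumulation assumed):

```python
import paddle


def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # x / sqrt(mean(x^2) + eps), scaled by the learned weight.
    variance = hidden_states.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)
```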
@@ -571,15 +578,11 @@
         self.fuse_attention_ffn = config.fuse_attention_ffn

         if config.sequence_parallel:
-            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
-                ColumnParallelLinear = MC2ColumnSeqParallelLinear
-                RowParallelLinear = MC2RowSeqParallelLinear
-            else:
-                ColumnParallelLinear = ColumnSequenceParallelLinear
-                RowParallelLinear = RowSequenceParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear

         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
@@ -611,15 +614,29 @@
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
             # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                try:
+                    import paddle_xpu_nn  # noqa: F821
+
+                    out = self.gate_up_fused_proj(x)
+                    out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                    out = self.down_proj(out)
+                    return out
+                except ImportError:
+                    gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
+                    out = self.down_proj(F.silu(gate_out) * up_out)
+                    return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
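The ImportError fallback above relies on the identity behind the fused kernel: the output of the fused gate/up projection is split in half and combined as silu(gate) * up. A hedged, standalone sketch of that equivalence (shapes illustrative; the `turn` argument of xpu_swiglu is not modeled here):

```python
import paddle
import paddle.nn.functional as F


def swiglu_unfused(gate_up_out):
    # gate_up_out: [..., 2 * intermediate_size]; first half is the gate, second half the up branch.
    gate_out, up_out = paddle.chunk(gate_up_out, chunks=2, axis=-1)
    return F.silu(gate_out) * up_out


# Example: a [2, 8] activation split into two [2, 4] halves and combined.
out = swiglu_unfused(paddle.randn([2, 8]))
```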
@@ -680,7 +697,7 @@
            )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() != "npu":
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -689,15 +706,11 @@
                 self.use_fused_rope = False

         if config.sequence_parallel:
-            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
-                ColumnParallelLinear = MC2ColumnSeqParallelLinear
-                RowParallelLinear = MC2RowSeqParallelLinear
-            else:
-                ColumnParallelLinear = ColumnSequenceParallelLinear
-                RowParallelLinear = RowSequenceParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear

         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -728,36 +741,36 @@
                         gather_output=False,
                     )
                 else:
-                    self.k_proj = nn.Linear(
+                    self.k_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
-                    self.v_proj = nn.Linear(
+                    self.v_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )

         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -771,7 +784,7 @@
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1419,6 +1432,11 @@
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask

Inline review discussion on this hunk (translated):
- When the passed-in …
- This looks nearly identical to the npu logic above; can the two branches share code?
- In theory they could be shared, but the npu branch hard-codes the dtype to …
- @SylarTiaNII please take a look?
- Per @wuhuachaocoding's suggestion, they are kept as two separate if/elif branches.
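The xpu branch mirrors the generic path: a boolean attention mask is turned into an additive mask in the model dtype, 0.0 where attention is allowed and the dtype minimum where it is masked. A small hedged illustration of that conversion (float16 chosen only for a concrete minimum value):

```python
import paddle

dtype = paddle.float16  # illustrative; the real code uses the model's dtype
mask = paddle.to_tensor([[True, True, False]])  # True = attend, False = masked

x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
additive_mask = paddle.where(mask, x, y).astype(dtype)
# additive_mask == [[0.0, 0.0, -65504.0]]; it is added to attention scores
# before softmax, so masked positions receive effectively zero probability.
```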
@@ -1698,6 +1716,15 @@
             self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if get_env_device() == "xpu":
+            try:
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    parallel_matmul as xpu_parallel_matmul,
+                )
+
+                self.xpu_parallel_matmul = xpu_parallel_matmul()
+            except ImportError:
+                self.xpu_parallel_matmul = None

     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1711,7 +1738,12 @@
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
+            logits = self.xpu_parallel_matmul(
+                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+            )
+        else:
+            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits

Review discussion on lines +1742 to +1743 (translated):
- Is the training argument really necessary? If the arguments could stay the same, wouldn't it be enough to just swap in the xpu implementation of parallel_matmul?
- There are two reasons for this: …
Review discussion (translated):
- What is this for?
- It is an XPU optimization for the accumulate_steps > 1 case, used together with the Linear layers from paddle_xpu below.
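Per the exchange above, the XPU matmul is constructed once at init time (it carries the state behind the accumulate_steps > 1 optimization) and takes an extra training flag, so it is not a drop-in replacement for parallel_matmul. A hedged, self-contained sketch of the guarded-fallback pattern the lm_head uses; the class below is illustrative only, and plain paddle.matmul stands in for paddlenlp's parallel_matmul:

```python
import paddle


class OptionalXPUMatmul:
    """Illustrative mirror of the lm_head pattern above, not code from the PR."""

    def __init__(self):
        try:
            # Present only when paddle_xpu is installed; otherwise fall back.
            from paddle_xpu.layers.nn import parallel_matmul as xpu_parallel_matmul

            self.xpu_parallel_matmul = xpu_parallel_matmul()
        except ImportError:
            self.xpu_parallel_matmul = None

    def __call__(self, hidden_states, weight, training=True):
        if self.xpu_parallel_matmul is not None:
            # The extra `training` flag is the signature difference discussed above.
            return self.xpu_parallel_matmul(
                hidden_states, weight, tensor_parallel_output=False, training=training
            )
        return paddle.matmul(hidden_states, weight)
```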