[XPU] llama add xpu support #8282

Merged · merged 9 commits · Apr 29, 2024
Changes from 6 commits
11 changes: 11 additions & 0 deletions llm/run_pretrain.py
@@ -46,6 +46,7 @@
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device


def add_start_docstrings(*docstr):
@@ -483,6 +484,16 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
            # It's OK; proceed without the accumulate_steps optimization
pass
Collaborator: What is this for?

Contributor Author: This is an XPU optimization for the case where accumulate_steps > 1; it is used together with the paddle_xpu Linear layers introduced below.
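For readers without access to paddle_xpu, the sketch below only illustrates the configuration pattern the guarded block relies on: a process-wide config that the vendor extension exposes and that its Linear layers later consult. The class is hypothetical and is not the real paddle_xpu.layers.nn.linear.LinearConfig.

class _HypotheticalLinearConfig:
    """Hypothetical stand-in: stores accumulate-steps settings for later Linear construction."""

    _accumulate_steps = 1
    _opt_enabled = False

    @classmethod
    def enable_accumulate_steps_opt(cls):
        cls._opt_enabled = True

    @classmethod
    def set_accumulate_steps(cls, steps):
        cls._accumulate_steps = steps

    @classmethod
    def accumulate_steps(cls):
        # Layers would read this when deciding whether to fuse gradient accumulation.
        return cls._accumulate_steps if cls._opt_enabled else 1


def configure_from_training_args(gradient_accumulation_steps):
    # Mirrors the guarded block above: only activate the optimization when
    # gradient accumulation is actually used.
    if gradient_accumulation_steps > 1:
        _HypotheticalLinearConfig.enable_accumulate_steps_opt()
        _HypotheticalLinearConfig.set_accumulate_steps(gradient_accumulation_steps)


configure_from_training_args(gradient_accumulation_steps=4)
assert _HypotheticalLinearConfig.accumulate_steps() == 4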


print("Final pre-training config:", config)

# Set the dtype for loading model
49 changes: 49 additions & 0 deletions paddlenlp/transformers/linear_utils.py
@@ -0,0 +1,49 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)
from paddle.nn import Linear

from paddlenlp.utils.tools import get_env_device

from .mc2_parallel_linear import MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear

if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear # noqa: F811
RowSequenceParallelLinear = MC2RowSeqParallelLinear # noqa: F811

elif get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear
from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401
from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

Linear = XPULinear # noqa: F811
ColumnParallelLinear = XPUColumnParallelLinear # noqa: F811
RowParallelLinear = XPURowParallelLinear # noqa: F811
ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear # noqa: F811
RowSequenceParallelLinear = XPURowSequenceParallelLinear # noqa: F811
except ImportError:
# It's OK, just use paddle's Linear layers
pass

98 changes: 65 additions & 33 deletions paddlenlp/transformers/llama/modeling.py
@@ -62,10 +62,6 @@
init_name_mappings,
)
from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -74,6 +70,8 @@
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device

from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
@@ -410,6 +408,15 @@
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
except ImportError:
raise NotImplementedError(
f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
)
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
@@ -571,15 +578,11 @@
self.fuse_attention_ffn = config.fuse_attention_ffn

if config.sequence_parallel:
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
@@ -611,15 +614,29 @@
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
# FIXME(yangjianbang): use paddle's native swiglu
if get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

out = self.gate_up_fused_proj(x)
out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
out = self.down_proj(out)
return out
except ImportError:
gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
out = self.down_proj(F.silu(gate_out) * up_out)
return out


x = swiglu(self.gate_up_fused_proj(x))
else:
x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -680,7 +697,7 @@
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -689,15 +706,11 @@
self.use_fused_rope = False

if config.sequence_parallel:
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -728,36 +741,36 @@
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -771,7 +784,7 @@
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
@@ -1419,6 +1432,11 @@
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)

Contributor Author: When the x and y passed in are integer scalars, paddle.where treats them as int64 tensors of shape [1] and performs a broadcast_add; see search.py for details. (A small sketch of the dtype-safe pattern follows after this hunk.)

Collaborator: This looks almost identical to the NPU logic above. Can the two be shared?

Contributor Author: In principle it could be shared, but the NPU branch hard-codes float16, while programs running on XPU may use either float16 or bfloat16. Should we modify the NPU module?

Collaborator: @SylarTiaNII could you take a look?

Contributor Author: Following @wuhuachaocoding's suggestion, this stays as two separate if/elif branches.

else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
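To make the dtype point above concrete, here is a minimal, standalone sketch of the mask-expansion pattern the XPU branch uses, assuming an installed Paddle build that supports these ops for the chosen dtype (paddle.finfo with bfloat16 requires a recent release). The helper name is invented for illustration and is not part of this PR.

import paddle


def expand_mask_to_dtype(bool_mask, dtype):
    # Build x/y as tensors of the *target* dtype before calling paddle.where.
    # Passing raw integer scalars would be treated as int64 tensors of shape
    # [1] and trigger a broadcast_add, which is the issue noted above.
    x = paddle.to_tensor(0.0, dtype=dtype)                      # kept positions
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)  # masked positions
    return paddle.where(bool_mask, x, y).astype(dtype)


# Works for both float16 and bfloat16, which is why the XPU branch cannot
# simply reuse the NPU branch that hard-codes float16.
mask = paddle.ones([1, 1, 4, 4]).astype("bool")
print(expand_mask_to_dtype(mask, paddle.bfloat16).dtype)  # paddle.bfloat16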
@@ -1698,6 +1716,15 @@
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
parallel_matmul as xpu_parallel_matmul,
)

self.xpu_parallel_matmul = xpu_parallel_matmul()
except ImportError:
self.xpu_parallel_matmul = None


def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1711,7 +1738,12 @@
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
Collaborator (on lines +1742 to +1743): Is the training argument really required? If the arguments could stay the same, wouldn't it be enough to simply swap in an XPU implementation of parallel_matmul?

Contributor Author: There are two reasons:

  • One XPU optimization requires parallel_matmul to be an object so it can store some state.
  • XPU needs the training flag to apply the optimization.

(An illustrative sketch of this pattern follows after this hunk.)

)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits
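As a rough illustration of the two points above, the hypothetical wrapper below keeps state across calls and branches on training. It is not the paddle_xpu implementation, and real tensor-parallel output handling is omitted.

import paddle


class HypotheticalParallelMatmul:
    """Illustrative stand-in, NOT the paddle_xpu implementation.

    Demonstrates the two points above: being an object lets the backend keep
    state across calls, and the training flag lets it pick a different code
    path for training vs. inference.
    """

    def __init__(self):
        self._eval_weight_cache = None  # example of state kept across calls

    def __call__(self, hidden_states, weight, tensor_parallel_output=True, training=True):
        # Real code would also all-gather or keep partial logits depending on
        # tensor_parallel_output; that is omitted here.
        if training:
            # The weight changes every optimizer step, so never cache it.
            return paddle.matmul(hidden_states, weight)
        if self._eval_weight_cache is None:
            # Inference path: reuse a prepared copy of the weight (here just a
            # dtype cast, purely for illustration).
            self._eval_weight_cache = weight.astype(hidden_states.dtype)
        return paddle.matmul(hidden_states, self._eval_weight_cache)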

