[XPU] llama add xpu support (#8282)
* [XPU] llama add xpu support

* fix

* use try import

* fix

* refine

* refine

* refine

* refine
dynamicheart authored Apr 29, 2024
1 parent 62b58a3 commit ba9d9bd
Showing 3 changed files with 135 additions and 33 deletions.
11 changes: 11 additions & 0 deletions llm/run_pretrain.py
@@ -46,6 +46,7 @@
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device


def add_start_docstrings(*docstr):
@@ -490,6 +491,16 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
            # It's OK; fall back to running without the accumulate_steps optimization
pass

print("Final pre-training config:", config)

# Set the dtype for loading model
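The guard added above follows the optional-dependency pattern used throughout this commit: try to import the vendor extension and silently fall back when it is missing. A minimal sketch of the same pattern, wrapping the calls from the diff in a helper (the helper name is illustrative, not part of the commit):

import paddle

def try_enable_xpu_accumulate_steps(accumulation_steps):
    """Enable the paddle_xpu linear accumulate-steps optimization if the package is present."""
    try:
        from paddle_xpu.layers.nn.linear import LinearConfig  # only available with the XPU toolkit
    except ImportError:
        return False  # paddle_xpu not installed; run without the optimization
    LinearConfig.enable_accumulate_steps_opt()
    LinearConfig.set_accumulate_steps(accumulation_steps)
    return True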
59 changes: 59 additions & 0 deletions paddlenlp/transformers/linear_utils.py
@@ -0,0 +1,59 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module replaces Paddle's native Linear implementations with vendor-customized ones when they are available.
"""

import paddle.distributed.fleet.meta_parallel as mpu
from paddle import nn
from paddle.distributed.fleet.utils import sequence_parallel_utils

from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)
from paddlenlp.utils.tools import get_env_device

Linear = nn.Linear
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear

if get_env_device() == "npu":
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear
RowSequenceParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear
from paddle_xpu.layers.nn import Linear as XPULinear
from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear
from paddle_xpu.layers.nn.sequence_parallel import (
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

Linear = XPULinear
ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear
RowSequenceParallelLinear = XPURowSequenceParallelLinear
except ImportError:
# If paddle_xpu is not installed, just use Paddle's native Linear implementations
pass
else:
# By default, use Paddle's native Linear implementations
pass
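Downstream code is expected to import the aliases from this module rather than from paddle directly, as the modeling.py changes below do. A short usage sketch (whether the aliases resolve to the XPU classes depends on the runtime device and on paddle_xpu being installed):

from paddlenlp.transformers import linear_utils
from paddlenlp.transformers.linear_utils import Linear

# Resolves to paddle_xpu's Linear on XPU (when installed), otherwise paddle.nn.Linear.
proj = Linear(1024, 4096, bias_attr=False)
# Picked by the attention/MLP layers when sequence_parallel is enabled.
column_parallel_cls = linear_utils.ColumnSequenceParallelLinear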
98 changes: 65 additions & 33 deletions paddlenlp/transformers/llama/modeling.py
@@ -62,10 +62,6 @@ def swiglu(x, y=None):
init_name_mappings,
)
from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -74,6 +70,8 @@ def swiglu(x, y=None):
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device

from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
@@ -410,6 +408,15 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
except ImportError:
raise NotImplementedError(
f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
)
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
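For reference, rms_norm_npu, xpu_rms_norm, and rms_norm_fused are drop-in replacements for the same computation. A plain-Paddle sketch of standard RMSNorm, assuming weight has shape [hidden_size]; this reference helper is not part of the commit:

import paddle

def rms_norm_reference(hidden_states, weight, epsilon):
    # Normalize by the reciprocal root-mean-square over the last axis, then scale.
    x = hidden_states.astype("float32")
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * paddle.rsqrt(variance + epsilon)
    return (x * weight.astype("float32")).astype(hidden_states.dtype)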
@@ -571,15 +578,11 @@ def __init__(self, config):
self.fuse_attention_ffn = config.fuse_attention_ffn

if config.sequence_parallel:
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
@@ -611,15 +614,29 @@ def __init__(self, config):
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
# FIXME(yangjianbang): use paddle's native swiglu
if get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821

out = self.gate_up_fused_proj(x)
out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
out = self.down_proj(out)
return out
except ImportError:
gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
out = self.down_proj(F.silu(gate_out) * up_out)
return out

x = swiglu(self.gate_up_fused_proj(x))
else:
x = swiglu(self.gate_proj(x), self.up_proj(x))
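The ImportError fallback above is numerically the same operation as the fused XPU path: the fused gate/up projection is split in half and combined as silu(gate) * up, which is exactly what swiglu computes. A standalone sketch with illustrative shapes:

import paddle
import paddle.nn.functional as F

fused = paddle.randn([2, 16])                              # stand-in for gate_up_fused_proj(x)
gate_out, up_out = paddle.chunk(fused, chunks=2, axis=-1)  # split into gate and up halves
out = F.silu(gate_out) * up_out                            # swiglu; shape [2, 8], fed to down_proj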
@@ -680,7 +697,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -689,15 +706,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
self.use_fused_rope = False

if config.sequence_parallel:
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -728,36 +741,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -771,7 +784,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
@@ -1469,6 +1482,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
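The new XPU branch mirrors the NPU float16 workaround: the boolean attention mask becomes an additive mask by selecting 0 where attention is allowed and the dtype minimum where it is blocked, with both fill values materialized as tensors for paddle.where. A small sketch of that conversion, assuming a boolean mask and float16 (the commit itself casts the mask to the target dtype before the where call):

import paddle

mask = paddle.to_tensor([[True, True, False]])  # True = attend, False = masked out
dtype = "float16"
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(paddle.float16).min, dtype=dtype)
additive_mask = paddle.where(mask, x, y)        # [[0.0, 0.0, -65504.0]]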
@@ -1748,6 +1766,15 @@ def __init__(self, config: LlamaConfig):
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
parallel_matmul as xpu_parallel_matmul,
)

self.xpu_parallel_matmul = xpu_parallel_matmul()
except ImportError:
self.xpu_parallel_matmul = None

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1761,7 +1788,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits


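All of the new branches key off get_env_device(), so no extra flags are needed; running the existing scripts on an XPU device picks up the XPU paths automatically, with paddle_xpu providing the optimized kernels when installed. A quick check of which path will be taken — the printed values are what one would expect on an XPU machine, not output captured from this commit:

import paddle
from paddlenlp.utils.tools import get_env_device

print(paddle.device.get_device())  # e.g. "xpu:0" when Paddle is built with XPU support
print(get_env_device())            # "xpu" -> enables the branches added in this commit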
