Skip to content

Commit

Permalink
update paddlenlp (PaddlePaddle#8267)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liujie0926 authored and dynamicheart committed Apr 17, 2024
1 parent ee88c12 commit 118ddd6
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 41 deletions.
6 changes: 6 additions & 0 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,12 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)

print("Final pre-training config:", config)

# Set the dtype for loading model
Expand Down
84 changes: 68 additions & 16 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
return assignment_list


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
def parallel_matmul(matmul_op, x: Tensor, y: Tensor, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
Expand All @@ -192,15 +192,15 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)
logits = matmul_op(input_parallel, y, transpose_y=False)

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
logits = matmul_op(x, y, transpose_y=False)
return logits


Expand Down Expand Up @@ -413,6 +413,10 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
import paddle_xpu_nn

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
Expand Down Expand Up @@ -582,13 +586,26 @@ def __init__(self, config):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
Linear = paddle_xpu.layers.nn.Linear # noqa: F821
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(
Expand Down Expand Up @@ -619,12 +636,12 @@ def __init__(self, config):
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
Expand Down Expand Up @@ -689,7 +706,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
if (
"gpu" not in paddle.device.get_device()
or "xpu" not in paddle.device.get_device()
or fused_rotary_position_embedding is None
):
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
"Will disable fuse rope. Try using latest gpu version of Paddle."
Expand All @@ -705,13 +726,26 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
Linear = paddle_xpu.layers.nn.Linear # noqa: F821
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
self.qkv_proj = ColumnParallelLinear(
Expand Down Expand Up @@ -741,36 +775,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
Expand All @@ -784,7 +818,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
Expand Down Expand Up @@ -1428,6 +1462,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
Expand Down Expand Up @@ -1708,6 +1747,13 @@ def __init__(self, config: LlamaConfig):
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if paddle.is_compiled_with_xpu():
from paddle_xpu.layers.nn import xpu_matmul # noqa: F401

self._xpu_matmul = xpu_matmul()
self.matmul_op = self._xpu_matmul.forward
else:
self.matmul_op = paddle.matmul

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
Expand All @@ -1721,7 +1767,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
matmul_op = self.matmul_op
if paddle.is_compiled_with_xpu():
from functools import partial

matmul_op = partial(matmul_op, training=self.training)

logits = parallel_matmul(matmul_op, hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits


Expand Down
4 changes: 2 additions & 2 deletions scripts/distribute/ci_case_auto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1854,8 +1854,6 @@ function before_hook_for_gpt() {
if [[ $FLAGS_install_deps == 0 ]];then
echo -e "\033[31m ---- Install requirements for GPT auto cases \033[0m"
python -m pip install -r requirements.txt --force-reinstall
python -m pip install --no-cache-dir https://paddlenlp.bj.bcebos.com/wheels/paddlenlp-ci-py3-none-any.whl --force-reinstall --no-dependencies
python -c "import paddlenlp; print('paddlenlp commit:',paddlenlp.version.commit)";
else
echo -e "\033[31m ---- Skip install requirements for GPT auto cases \033[0m"
fi
Expand Down Expand Up @@ -1886,6 +1884,8 @@ function before_hook_for_llama() {
env | grep FLAGS
export http_proxy=${proxy}
export https_proxy=${proxy}
python -m pip install -r $root_path/requirements.txt
python -m pip install -r $root_path/requirements-dev.txt
if [[ ! $FLAGS_download_data =~ "llama" ]];then
echo -e "\033[31m ---- Download LLaMA data \033[0m"
rm -rf data
Expand Down
4 changes: 1 addition & 3 deletions scripts/distribute/ci_case_dy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,6 @@ function before_hook_for_gpt() {
if [[ $FLAGS_install_deps == 0 ]];then
echo -e "\033[31m ---- Install requirements for GPT dygraph cases \033[0m"
python -m pip install -r requirements.txt --force-reinstall
python -m pip install --no-cache-dir https://paddlenlp.bj.bcebos.com/wheels/paddlenlp-ci-py3-none-any.whl --force-reinstall --no-dependencies
python -c "import paddlenlp; print('paddlenlp commit:',paddlenlp.version.commit)";
else
echo -e "\033[31m ---- Skip install requirements for GPT dygraph cases \033[0m"
fi
Expand Down Expand Up @@ -614,7 +612,7 @@ function before_hook_for_llm_gpt() {
export http_proxy=${proxy}
export https_proxy=${proxy}
python -m pip install -r $root_path/requirements.txt
python -m pip install regex
python -m pip install -r $root_path/requirements-dev.txt
if [[ ! $FLAGS_download_data =~ "llm_gpt" ]];then
echo -e "\033[31m ---- Download llm GPT data \033[0m"
rm -rf data
Expand Down
35 changes: 15 additions & 20 deletions scripts/distribute/run_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,24 @@ install_paddle(){
}

install_paddlenlp(){
echo -e "\033[31m ---- Install paddlenlp \033"
cd ${nlp_dir}
echo -e "\033[31m ---- Install paddlenlp by set PYTHONPATH \033"
export PYTHONPATH=${nlp_dir}:$PYTHONPATH
sed -i -e "s/paddlenlp/#paddlenlp/g" model_zoo/gpt-3/requirements.txt
export http_proxy=${proxy} && export https_proxy=${proxy}
python -m pip uninstall paddlenlp -y
rm -rf build/ && rm -rf paddlenlp.egg-info/ && rm -rf dist/
python -m pip install --ignore-installed -r requirements.txt
python -m pip install --ignore-installed -r requirements-dev.txt
python setup.py install
python setup.py build_ext
python setup.py bdist_wheel
unset http_proxy && unset https_proxy
cd -
# export http_proxy=${proxy} && export https_proxy=${proxy}
# python -m pip uninstall paddlenlp -y
# rm -rf build/ && rm -rf paddlenlp.egg-info/ && rm -rf dist/
# python -m pip install --ignore-installed -r requirements.txt
# python -m pip install --ignore-installed -r requirements-dev.txt
# python setup.py install
# python setup.py build_ext
# python setup.py bdist_wheel
# unset http_proxy && unset https_proxy
# cd -
python -c "import paddlenlp; print('paddlenlp commit:',paddlenlp.version.commit)";
}
####################################
get_diff_TO_case(){
cd ${nlp_dir}
export FLAGS_paddlenlp=0
for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{print $NF}'`;do
arr_file_name=(${file_name//// })
dir1=${arr_file_name[0]}
Expand All @@ -70,9 +69,6 @@ for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{pri
dir4=${arr_file_name[3]}
file_item=$dir1/$dir2/$dir3/$dir4
echo "file_name:"${file_name}, "path:"${file_item}
if [[ ${dir1} =~ "paddlenlp" ]];then
export FLAGS_paddlenlp=1
fi
if [ ! -f ${file_name} ];then # 针对pr删掉文件
continue
elif [[ ${file_name##*.} == "md" ]] || [[ ${file_name##*.} == "rst" ]] || [[ ${dir1} == "docs" ]];then
Expand Down Expand Up @@ -129,10 +125,9 @@ if [[ ${#case_list[*]} -ne 0 ]];then

# Install paddle
install_paddle
if [[ FLAGS_paddlenlp -eq 1 ]] || [[ $(contain_case llama_auto ${case_list[@]}; echo $?) -eq 1 ]];then
# 安装本地paddlenlp
install_paddlenlp
fi
# Install paddlenlp
install_paddlenlp

case_num=1
export FLAGS_install_deps=0
export FLAGS_download_data=""
Expand Down

0 comments on commit 118ddd6

Please sign in to comment.