From 5e1f01fd974f63ec1fc23ab41cd87809d3929711 Mon Sep 17 00:00:00 2001 From: thinking-computer <145450064+thinking-computer@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:11:42 +0800 Subject: [PATCH 01/11] [Custom Devices] feat(sdaa): support sdaa backend infer (#9570) 1.add sdaa python paddlenlp_ops setup and README 2.update llm scripts and README --- csrc/sdaa/README.md | 15 ++ csrc/sdaa/python/paddlenlp_ops/__init__.py | 15 ++ csrc/sdaa/setup_sdaa.py | 59 +++++++ docs/llm/sdaa/llama/README.md | 1 + llm/docs/predict/inference.md | 15 +- llm/docs/predict/installation.md | 4 +- llm/sdaa/llama/README.md | 187 +++++++++++++++++++++ llm/sdaa/llama/dynamic_infer_llama_sdaa.sh | 17 ++ llm/sdaa/llama/static_export_llama_sdaa.sh | 18 ++ llm/sdaa/llama/static_infer_llama_sdaa.sh | 18 ++ 10 files changed, 341 insertions(+), 8 deletions(-) create mode 100644 csrc/sdaa/README.md create mode 100644 csrc/sdaa/python/paddlenlp_ops/__init__.py create mode 100644 csrc/sdaa/setup_sdaa.py create mode 120000 docs/llm/sdaa/llama/README.md create mode 100644 llm/sdaa/llama/README.md create mode 100644 llm/sdaa/llama/dynamic_infer_llama_sdaa.sh create mode 100644 llm/sdaa/llama/static_export_llama_sdaa.sh create mode 100644 llm/sdaa/llama/static_infer_llama_sdaa.sh diff --git a/csrc/sdaa/README.md b/csrc/sdaa/README.md new file mode 100644 index 000000000000..636ce1ccdca8 --- /dev/null +++ b/csrc/sdaa/README.md @@ -0,0 +1,15 @@ +# PaddleNLP 自定义 OP + +此文档介绍如何编译安装 PaddleNLP SDAA 自定义 OP。 + +# 1. 安装 PaddleCustomDevice + +参考 [PaddleCustomDevice SDAA 安装文档](https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/sdaa/README_cn.md) 进行安装 + + +# 2. 安装 paddlenlp_ops +```shell +python setup_sdaa.py build bdist_wheel + +pip install dist/paddlenlp_ops*.whl +``` diff --git a/csrc/sdaa/python/paddlenlp_ops/__init__.py b/csrc/sdaa/python/paddlenlp_ops/__init__.py new file mode 100644 index 000000000000..1a638a971890 --- /dev/null +++ b/csrc/sdaa/python/paddlenlp_ops/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_sdaa.sdaa_ext import * diff --git a/csrc/sdaa/setup_sdaa.py b/csrc/sdaa/setup_sdaa.py new file mode 100644 index 000000000000..a8722ecaac0a --- /dev/null +++ b/csrc/sdaa/setup_sdaa.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from setuptools import Distribution, setup + +packages = [] +package_data = {} + + +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + +def main(): + setup( + name="paddlenlp_ops", + version="0.0.0", + description="PaddleNLP SDAA CustomOps", + long_description="", + long_description_content_type="text/markdown", + author_email="Paddle-better@baidu.com", + maintainer="PaddlePaddle", + maintainer_email="Paddle-better@baidu.com", + project_urls={}, + license="Apache Software License", + packages=[ + "paddlenlp_ops", + ], + include_package_data=True, + package_data={ + "": ["*.py"], + }, + package_dir={ + "": "python", + }, + zip_safe=False, + distclass=BinaryDistribution, + entry_points={"console_scripts": []}, + classifiers=[], + keywords="PaddleNLP SDAA CustomOps", + ) + + +if __name__ == "__main__": + main() diff --git a/docs/llm/sdaa/llama/README.md b/docs/llm/sdaa/llama/README.md new file mode 120000 index 000000000000..d20a5694cb1c --- /dev/null +++ b/docs/llm/sdaa/llama/README.md @@ -0,0 +1 @@ +../../../../llm/sdaa/llama/README.md \ No newline at end of file diff --git a/llm/docs/predict/inference.md b/llm/docs/predict/inference.md index bbeeff3d9aae..b2615bc2446a 100644 --- a/llm/docs/predict/inference.md +++ b/llm/docs/predict/inference.md @@ -39,13 +39,13 @@ PaddleNLP 中已经添加高性能推理模型相关实现,已验证过的模 PaddleNLP 提供了多种硬件平台和精度支持,包括: -| Precision | Hopper| Ada | Ampere | Turing | Volta | 昆仑XPU | 昇腾NPU | 海光K100 | 燧原GCU | x86 CPU | -|:--------------:|:-----:|:---:|:------:|:------:|:-----:|:------:|:-------:|:-------:|:------:|:-------:| -| FP32 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| FP16 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| BF16 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | -| INT8 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | -| FP8 | 🚧 | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Precision | Hopper| Ada | Ampere | Turing | Volta | 昆仑XPU | 昇腾NPU | 海光K100 | 燧原GCU | 太初SDAA| x86 CPU | +|:--------------:|:-----:|:---:|:------:|:------:|:-----:|:------:|:-------:|:-------:|:------:|:------:|:-------:| +| FP32 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| FP16 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| BF16 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | +| INT8 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | +| FP8 | 🚧 | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ## 3. 推理参数 @@ -196,6 +196,7 @@ python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat -- - [昇腾NPU](../../npu/llama/README.md) - [海光K100](../dcu_install.md) - [燧原GCU](../../gcu/llama/README.md) +- [太初SDAA](../../sdaa/llama/README.md) - [X86 CPU](../cpu_install.md) ## 致谢 diff --git a/llm/docs/predict/installation.md b/llm/docs/predict/installation.md index e88fb5b551f6..4d077c1c9ed6 100644 --- a/llm/docs/predict/installation.md +++ b/llm/docs/predict/installation.md @@ -16,6 +16,8 @@ cd PaddleNLP/csrc && python setup_cuda.py install cd PaddleNLP/csrc/xpu/src && sh cmake_build.sh #DCU设备安装自定义算子 cd PaddleNLP/csrc && python setup_hip.py install +#SDAA设备安装自定义算子 +cd PaddleNLP/csrc/sdaa && python setup_sdaa.py install ``` 到达运行目录,即可开始: @@ -32,4 +34,4 @@ cd PaddleNLP/llm 获取最佳推理性能: -- [最佳实践](./best_practices.md) \ No newline at end of file +- [最佳实践](./best_practices.md) diff --git a/llm/sdaa/llama/README.md b/llm/sdaa/llama/README.md new file mode 100644 index 000000000000..0e5afa03a42c --- /dev/null +++ b/llm/sdaa/llama/README.md @@ -0,0 +1,187 @@ +## 🚣‍♂️ 使用PaddleNLP在太初sdaa 下运行Llama-2-13b-chat模型 🚣 + +PaddleNLP在太初sdaa上对Llama-2-13b-chat模型进行了深度适配和优化,实现了sdaa device推理入口和GPU的基本统一,仅需修改device即可完成推理任务的迁移。 + +## 🚀 快速开始 🚀 + +### 0. 机器准备。快速开始之前,您需要准备一台插有太初T100加速卡的机器,要求如下: + +| 芯片类型 | 驱动版本 | +| --- | --- | +| 太初T100 | 1.3.0| + + +### 1. 环境准备:(这将花费您5~15min 时间) + +#### 1.1 拉取镜像 +```bash +# 注意此镜像包含预编译的飞桨安装包, TecoDriver, TecoToolKit等,可以一键运行paddlenlp模型 +wget http://mirrors.tecorigin.com/repository/teco-3rd-repo/custom_device/ubuntu22.04/x86_64/1.3.0/paddle_sdaa_1.3.0_llm_infer.tar +docker load < paddle_sdaa_1.3.0_llm_infer.tar +``` + +#### 1.2 参考如下命令启动容器 +```bash +docker run -itd --name="paddle-sdaa-dev" --net=host --privileged --cap-add SYS_PTRACE --cap-add SYS_ADMIN --shm-size 128g jfrog.tecorigin.net/tecotp-docker/release/ubuntu22.04/x86_64/paddle_sdaa:1.3.0-llm-infer /bin/bash +``` + +#### 1.3 下载PaddleNLP仓库代码,并安装依赖 +```bash +# PaddleNLP是基于PaddlePaddle『飞桨』的自然语言处理和大语言模型(LLM)开发库,存放了基于『飞桨』框架实现的各种大模型,Llama-2-13b-chat模型也包含其中。为了便于您更好地使用PaddleNLP,您需要clone整个仓库。 +git clone https://github.com/PaddlePaddle/PaddleNLP.git +cd PaddleNLP +export PYTHONPATH=/path/to/PaddleNLP:$PYTHONPATH +pip install -r requirements.txt +cd csrc/sdaa && python setup_sdaa.py install && cd ../../llm/sdaa/llama +``` +### 2. 推理:(这将花费您15~30min时间) +#### 2.1 动态图分布式推理 + +执行如下命令进行推理: +```bash +bash dynamic_infer_llama_sdaa.sh +``` +首次推理会自动下载权重,可以使用自动下载的权重,或者下载之后指定权重路径。成功运行后,可以查看到推理结果的生成。 + +样例将下载的权重meta-llama/Llama-2-13b-chat文件夹保存到/workspace/weights,示例如下: +``` +[2024-12-10 15:42:51,992] [ INFO] - set state for layer 30 +[2024-12-10 15:42:53,666] [ INFO] - set state for layer 31 +[2024-12-10 15:42:55,202] [ INFO] - set state for layer 32 +[2024-12-10 15:42:56,724] [ INFO] - set state for layer 33 +[2024-12-10 15:42:58,314] [ INFO] - set state for layer 34 +[2024-12-10 15:43:00,041] [ INFO] - set state for layer 35 +[2024-12-10 15:43:01,515] [ INFO] - set state for layer 36 +[2024-12-10 15:43:03,034] [ INFO] - set state for layer 37 +[2024-12-10 15:43:04,746] [ INFO] - set state for layer 38 +[2024-12-10 15:43:06,390] [ INFO] - set state for layer 39 +[2024-12-10 15:43:08,682] [ INFO] - We are using to load '/workspace/weights/meta-llama/Llama-2-13b-chat'. +[2024-12-10 15:43:08,682] [ INFO] - Loading configuration file /workspace/weights/meta-llama/Llama-2-13b-chat/config.json +[2024-12-10 15:43:08,683] [ INFO] - Loading configuration file /workspace/weights/meta-llama/Llama-2-13b-chat/generation_config.json +[2024-12-10 15:43:08,752] [ INFO] - Start predict +[2024-12-10 15:43:08,789] [ INFO] - We are using to load '/workspace/weights/meta-llama/Llama-2-13b-chat'. +[2024-12-10 15:43:08,806] [ INFO] - Start read result message +[2024-12-10 15:43:08,806] [ INFO] - Current path is /workspace/paddlenlp/llm +[2024-12-10 15:43:29,178] [ INFO] - running spend 20.372194528579712 +[2024-12-10 15:43:29,187] [ INFO] - Finish read result message +[2024-12-10 15:43:29,192] [ INFO] - End predict +***********Source********** +解释一下温故而知新 +***********Target********** + +***********Output********** + "温故而知新" (wēn gù er zhī xīn) is a Chinese idiom that means "to understand the old in order to know the new." It is often used to convey the idea that one must have a deep understanding of the past and traditional ways of doing things in order to truly appreciate and understand new ideas and innovations. + +The phrase is often used in the context of education, where students are encouraged to study the classics and learn from the past in order to gain a solid foundation for understanding new concepts and ideas. It is also used in business and technology, where companies may look to the past for inspiration and guidance as they develop new products and services. + +In essence, "温故而知新" suggests that one cannot truly understand the new without first understanding the old, and that a deep appreciation for the past is essential for making progress and innovation. +``` +#### 2.2 静态图分布式推理 + +##### 2.2.1 静态图导出 + +执行如下命令进行静态图导出,为静态图分布式推理做好准备: +```bash +bash static_export_llama_sdaa.sh +``` +成功运行后,可以查看到模型导出的结果,样例如下: +```bash +[2024-12-10 15:30:28,991] [ INFO] - set state for layer 24 +[2024-12-10 15:30:30,246] [ INFO] - set state for layer 25 +[2024-12-10 15:30:31,586] [ INFO] - set state for layer 26 +[2024-12-10 15:30:32,892] [ INFO] - set state for layer 27 +[2024-12-10 15:30:34,228] [ INFO] - set state for layer 28 +[2024-12-10 15:30:35,530] [ INFO] - set state for layer 29 +[2024-12-10 15:30:36,925] [ INFO] - set state for layer 30 +[2024-12-10 15:30:38,233] [ INFO] - set state for layer 31 +[2024-12-10 15:30:39,635] [ INFO] - set state for layer 32 +[2024-12-10 15:30:40,992] [ INFO] - set state for layer 33 +[2024-12-10 15:30:42,375] [ INFO] - set state for layer 34 +[2024-12-10 15:30:43,717] [ INFO] - set state for layer 35 +[2024-12-10 15:30:45,076] [ INFO] - set state for layer 36 +[2024-12-10 15:30:46,423] [ INFO] - set state for layer 37 +[2024-12-10 15:30:47,827] [ INFO] - set state for layer 38 +[2024-12-10 15:30:49,216] [ INFO] - set state for layer 39 +[2024-12-10 15:30:51,136] [ INFO] - We are using to load '/workspace/weights/meta-llama/Llama-2-13b-chat'. +[2024-12-10 15:30:51,136] [ INFO] - Loading configuration file /workspace/weights/meta-llama/Llama-2-13b-chat/config.json +[2024-12-10 15:30:51,137] [ INFO] - Loading configuration file /workspace/weights/meta-llama/Llama-2-13b-chat/generation_config.json +/root/miniconda3/envs/paddle_env/lib/python3.10/site-packages/paddle/jit/dy2static/program_translator.py:747: UserWarning: full_graph=False don't support input_spec arguments. It will not produce any effect. +You can set full_graph=True, then you can assign input spec. + + warnings.warn( +/root/miniconda3/envs/paddle_env/lib/python3.10/site-packages/paddle/jit/api.py:1106: UserWarning: What you save is a function, and `jit.save` will generate the name of the model file according to `path` you specify. When loading these files with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as the inference result of the function you saved. + warnings.warn( +I1210 15:30:58.707722 1174678 program_interpreter.cc:242] New Executor is Running. +[2024-12-10 15:31:10,381] [ INFO] - Configuration saved in ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32/config.json +[2024-12-10 15:31:10,382] [ INFO] - Configuration saved in ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32/generation_config.json +[2024-12-10 15:31:10,382] [ INFO] - tokenizer config file saved in ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32/tokenizer_config.json +[2024-12-10 15:31:10,382] [ INFO] - Special tokens file saved in ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32/special_tokens_map.json +[2024-12-10 15:31:10,383] [ INFO] - Chat-template config file saved in ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32/chat_template.json +LAUNCH INFO 2024-12-10 15:31:12,346 Pod completed +LAUNCH INFO 2024-12-10 15:31:12,347 Exit code 0 +``` +##### 2.2.2 静态图分布式推理 + +执行如下命令进行静态图分布式推理: +```bash +bash static_infer_llama_sdaa.sh +``` +成功运行后,可以查看到推理结果的生成,样例如下: +```bash +[2024-12-10 15:36:24,150] [ INFO] topology.py:370 - Total 4 data comm group(s) create successfully! +[2024-12-10 15:36:24,150] [ INFO] topology.py:370 - Total 1 model comm group(s) create successfully! +[2024-12-10 15:36:24,150] [ INFO] topology.py:370 - Total 4 sharding comm group(s) create successfully! +[2024-12-10 15:36:24,150] [ INFO] topology.py:290 - HybridParallelInfo: rank_id: 0, mp_degree: 4, sharding_degree: 1, pp_degree: 1, dp_degree: 1, sep_degree: 1, mp_group: [0, 1, 2, 3], sharding_group: [0], pp_group: [0], dp_group: [0], sep:group: None, check/clip group: [0, 1, 2, 3] +[2024-12-10 15:36:24,152] [ INFO] - We are using to load 'output_dir/exported_model/llama2_13b_chat_wint8_block_size32'. +[2024-12-10 15:36:24,164] [ INFO] - We are using to load 'output_dir/exported_model/llama2_13b_chat_wint8_block_size32'. +[2024-12-10 15:36:24,164] [ INFO] - Loading configuration file output_dir/exported_model/llama2_13b_chat_wint8_block_size32/config.json +[2024-12-10 15:36:24,165] [ INFO] - We are using to load 'output_dir/exported_model/llama2_13b_chat_wint8_block_size32'. +[2024-12-10 15:36:24,165] [ INFO] - Loading configuration file output_dir/exported_model/llama2_13b_chat_wint8_block_size32/config.json +[2024-12-10 15:36:24,198] [ INFO] - We are using to load 'output_dir/exported_model/llama2_13b_chat_wint8_block_size32'. +[2024-12-10 15:36:24,198] [ INFO] - Loading configuration file output_dir/exported_model/llama2_13b_chat_wint8_block_size32/config.json +[2024-12-10 15:36:24,199] [ INFO] - Loading configuration file output_dir/exported_model/llama2_13b_chat_wint8_block_size32/generation_config.json +I1210 15:36:24.239424 1334951 analysis_predictor.cc:2142] MKLDNN is enabled +I1210 15:36:24.239473 1334951 analysis_predictor.cc:2167] CustomDevice is enabled +I1210 15:36:24.239486 1334951 analysis_predictor.cc:2210] Model is mixed precision type with float16, we will use a new PassStrategy. Note that only GPU/XPU backend is supported for now. +I1210 15:36:24.239490 1334951 analysis_predictor.cc:2259] Ir optimization is turned off, no ir pass will be executed. +--- Running analysis [ir_graph_build_pass] +I1210 15:36:24.260483 1334951 executor.cc:183] Old Executor is Running. +--- Running analysis [ir_analysis_pass] +--- Running analysis [ir_params_sync_among_devices_pass] +I1210 15:36:25.863914 1334951 ir_params_sync_among_devices_pass.cc:140] Sync params from CPU to sdaa:0 +--- Running analysis [adjust_cudnn_workspace_size_pass] +--- Running analysis [inference_op_replace_pass] +--- Running analysis [save_optimized_model_pass] +--- Running analysis [ir_graph_to_program_pass] +I1210 15:36:29.991195 1334951 analysis_predictor.cc:2348] ======= ir optimization completed ======= +I1210 15:36:30.000306 1334951 gen_comm_id_helper.cc:212] Server listening on: 127.0.1.1:36942 successful. +I1210 15:36:30.088883 1334951 task_node.cc:43] Constructing TaskNode for DistModelInf. The TaskNode's id is: 0. And the TaskNode's max_run_time and max_slot_num will be set to 1. +LAUNCH INFO 2024-12-10 15:37:24,254 Pod completed +LAUNCH INFO 2024-12-10 15:37:24,254 Exit code 0 +I1210 15:36:30.189157 1334951 server.cpp:1107] Server[paddle::distributed::MessageServiceImpl] is serving on port=36942. +I1210 15:36:30.189195 1334951 server.cpp:1110] Check out http://dmx-19:36942 in web browser. +I1210 15:36:30.189320 1334951 message_bus.cc:201] Message bus's listen port thread starts successful. +[2024-12-10 15:36:31,284] [ INFO] - Start predict +[2024-12-10 15:36:31,296] [ INFO] - preprocess spend 0.010512113571166992 +[2024-12-10 15:36:31,355] [ INFO] - We are using to load 'output_dir/exported_model/llama2_13b_chat_wint8_block_size32'. +[2024-12-10 15:36:31,378] [ INFO] - Start read result message +[2024-12-10 15:36:31,378] [ INFO] - Current path is /workspace/paddlenlp/llm +[2024-12-10 15:37:22,118] [ INFO] - running spend 50.736462116241455 +[2024-12-10 15:37:22,125] [ INFO] - Finish read result message +[2024-12-10 15:37:22,132] [ INFO] - End predict +***********Source********** +解释一下温故而知新 +***********Target********** + +***********Output********** + "温故而知新" (wēn gù er zhī xīn) is a Chinese idiom that means "to know the old in order to discern the new." It is often used to describe the idea that one can gain a deeper understanding of something new by studying and appreciating the past. + +The word "温" (wēn) in this idiom means "old" or "past," and "故" (gù) means "olden days" or "former times." The word "知" (zhī) means "to know" or "to understand," and "新" (xīn) means "new." + +The idiom "温故而知新" suggests that by studying and understanding the past, one can gain a deeper appreciation for the present and make more informed decisions about the future. It is often used in the context of learning from history, understanding cultural traditions, and appreciating the value of experience and wisdom. + +For example, if someone is trying a new type of food for the first time, they might say "I need to study the old recipes to know the new flavors" (我需要学习古老的菜谱,才能了解新的味道). This means that by understanding the traditional methods and ingredients used in the past, they can better appreciate the new dish and its unique qualities. + +Overall, "温故而知新" is a reminder that understanding the past can help us navigate the present and make more informed decisions about the future. +I1210 15:37:22.926474 1334951 server.cpp:1167] Server[paddle::distributed::MessageServiceImpl] is going to quit +``` diff --git a/llm/sdaa/llama/dynamic_infer_llama_sdaa.sh b/llm/sdaa/llama/dynamic_infer_llama_sdaa.sh new file mode 100644 index 000000000000..cff575a74497 --- /dev/null +++ b/llm/sdaa/llama/dynamic_infer_llama_sdaa.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SDAA_VISIBLE_DEVICES=0,1,2,3 PADDLE_XCCL_BACKEND=sdaa \ +python -m paddle.distributed.launch ./../../predict/predictor.py --model_name_or_path meta-llama/Llama-2-13b-chat --inference_model --dtype float16 --block_attn --quant_type weight_only_int8 --device sdaa --block_size 32 + diff --git a/llm/sdaa/llama/static_export_llama_sdaa.sh b/llm/sdaa/llama/static_export_llama_sdaa.sh new file mode 100644 index 000000000000..e1b25a853ccf --- /dev/null +++ b/llm/sdaa/llama/static_export_llama_sdaa.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export FLAGS_enable_pir_api=0 +SDAA_VISIBLE_DEVICES=0,1,2,3 PADDLE_XCCL_BACKEND=sdaa \ +python -m paddle.distributed.launch ./../../predict/export_model.py --model_name_or_path meta-llama/Llama-2-13b-chat --inference_model --output_path ./output_dir/exported_model/llama2_13b_chat_wint8_block_size32 --dtype float16 --block_attn --quant_type weight_only_int8 --device sdaa --block_size 32 + diff --git a/llm/sdaa/llama/static_infer_llama_sdaa.sh b/llm/sdaa/llama/static_infer_llama_sdaa.sh new file mode 100644 index 000000000000..ed38f9401a9f --- /dev/null +++ b/llm/sdaa/llama/static_infer_llama_sdaa.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export FLAGS_enable_pir_api=0 +SDAA_VISIBLE_DEVICES=0,1,2,3 PADDLE_XCCL_BACKEND=sdaa \ +python -m paddle.distributed.launch ./../../predict/predictor.py --model_name_or_path output_dir/exported_model/llama2_13b_chat_wint8_block_size32 --dtype float16 --mode static --inference_model 1 --quant_type weight_only_int8 --block_attn 1 --device sdaa --block_size 32 + From f3ba5b3c6a9b1612c6d87fd069fc8a0d45dfb844 Mon Sep 17 00:00:00 2001 From: lugimzzz <63761690+lugimzzz@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:41:01 +0800 Subject: [PATCH 02/11] [llm]update dpo criterion (#9620) * update dpo criterion * update dpo criterion --- paddlenlp/trl/dpo_criterion.py | 110 ++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py index be454e2ce4d1..6c3e111eda65 100644 --- a/paddlenlp/trl/dpo_criterion.py +++ b/paddlenlp/trl/dpo_criterion.py @@ -18,6 +18,7 @@ import paddle.nn as nn import paddle.nn.functional as F from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy +from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp from paddlenlp.transformers import ( AllGatherVarlenOp, @@ -28,7 +29,6 @@ ) from paddlenlp.transformers.model_outputs import CausalLMOutputWithPast from paddlenlp.utils import infohub -from paddlenlp.utils.tools import get_env_device class DPOCriterion(nn.Layer): @@ -148,7 +148,7 @@ def dpo_logps( if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0) - hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0) + hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0) hidden_states = AllGatherVarlenOp.apply(hidden_states) else: labels = labels.flatten() @@ -156,8 +156,17 @@ def dpo_logps( labels = paddle.take_along_axis(labels, sparse_tgt_idx, axis=0) hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) - hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0) - + hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0) + elif use_fused_head_and_loss_fn: + if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape( + [ + -1, + self.config.max_sequence_length, + hidden_states.shape[-1], + ] + ) if use_fused_head_and_loss_fn: per_token_logps = -fused_head_and_loss_fn( hidden_states, @@ -194,64 +203,65 @@ def dpo_logps( if len(response_indexs.shape) == 3: response_indexs = response_indexs[0] + + offset = 1 if self.ignore_eos_token else 0 if use_sparse_head_and_loss_fn: chosen_logps = paddle.stack( - [(per_token_logps[response_index[1] : response_index[2]]).sum() for response_index in response_indexs], + [ + ( + paddle.gather( + per_token_logps.reshape([-1]), + paddle.arange(response_index[1], response_index[2], dtype=paddle.int32), + axis=0, + ).sum() + ) + for response_index in response_indexs + ], axis=0, ) rejected_logps = paddle.stack( - [(per_token_logps[response_index[2] : response_index[3]]).sum() for response_index in response_indexs], + [ + ( + paddle.gather( + per_token_logps.reshape([-1]), + paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32), + axis=0, + ).sum() + ) + for response_index in response_indexs + ], axis=0, ) else: - if get_env_device() == "npu": - chosen_list = [] - for response_index in response_indexs: - begin = response_index[1] - end = response_index[2] - one_data = paddle.ones_like(per_token_logps[0]) - mask_data = paddle.zeros_like(per_token_logps[0]) - paddle.assign(one_data._slice(begin, end), mask_data._slice(begin, end)) - chosen_list.append((per_token_logps[0] * mask_data).sum()) - chosen_logps = paddle.stack(chosen_list, axis=0) - rejected_list = [] - for response_index in response_indexs: - begin = response_index[2] - if self.ignore_eos_token: - begin += 1 - end = response_index[3] - one_data = paddle.ones_like(per_token_logps[0]) - mask_data = paddle.zeros_like(per_token_logps[0]) - paddle.assign(one_data._slice(begin, end), mask_data._slice(begin, end)) - rejected_list.append((per_token_logps[0] * mask_data).sum()) - rejected_logps = paddle.stack(rejected_list, axis=0) - else: - chosen_logps = paddle.stack( - [ - (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum() - for response_index in response_indexs - ], - axis=0, - ) - if self.ignore_eos_token: - rejected_logps = paddle.stack( - [ - (per_token_logps[response_index[0]][response_index[2] + 1 : response_index[3]]).sum() - for response_index in response_indexs - ], - axis=0, + chosen_logps = paddle.stack( + [ + ( + paddle.gather( + paddle.gather(per_token_logps, response_index[0], axis=0), + paddle.arange(response_index[1], response_index[2], dtype=paddle.int32), + axis=0, + ).sum() ) - else: - rejected_logps = paddle.stack( - [ - (per_token_logps[response_index[0]][response_index[2] : response_index[3]]).sum() - for response_index in response_indexs - ], - axis=0, + for response_index in response_indexs + ], + axis=0, + ) + rejected_logps = paddle.stack( + [ + ( + paddle.gather( + paddle.gather(per_token_logps, response_index[0], axis=0), + paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32), + axis=0, + ).sum() ) + for response_index in response_indexs + ], + axis=0, + ) sft_loss = -chosen_logps.sum() / (chosen_labels != 0).sum() if average_log_prob: - chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] + chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] - offset rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2] chosen_logps /= chosen_response_length.astype("float32") rejected_logps /= rejected_response_length.astype("float32") From da41c4f252da63fc468628a1cba07bf6a310e750 Mon Sep 17 00:00:00 2001 From: lugimzzz <63761690+lugimzzz@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:41:11 +0800 Subject: [PATCH 03/11] [llm]add adam-mini (#9542) * add adam-mini * fix following comments --- docs/trainer.md | 3 + llm/docs/dpo.md | 1 + llm/docs/finetune.md | 1 + paddlenlp/trainer/trainer.py | 5 + paddlenlp/trainer/trainer_utils.py | 1 + paddlenlp/trainer/training_args.py | 2 + paddlenlp/utils/__init__.py | 1 + paddlenlp/utils/optimizer.py | 151 +++++++++++++++++++++++++++++ tests/fixtures/llm/adamw_mini.yaml | 35 +++++++ tests/llm/test_adamw_mini.py | 53 ++++++++++ 10 files changed, 253 insertions(+) create mode 100644 paddlenlp/utils/optimizer.py create mode 100644 tests/fixtures/llm/adamw_mini.yaml create mode 100644 tests/llm/test_adamw_mini.py diff --git a/docs/trainer.md b/docs/trainer.md index e5c33f21e848..d643c99268e4 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -691,6 +691,9 @@ Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并 --optim 优化器名称,默认为adamw,(`str`, 可选,默认为 `adamw`) The optimizer to use. (default: adamw) + 可能的值为: + - `"adamw"` + - `"adamw_mini"` --report_to 日志可视化显示,默认使用visualdl可视化展示。(可选,默认为 None,展示所有) diff --git a/llm/docs/dpo.md b/llm/docs/dpo.md index 639059ddd0d0..4ca084e5834b 100644 --- a/llm/docs/dpo.md +++ b/llm/docs/dpo.md @@ -119,6 +119,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo - `unified_checkpoint`: 是否使用统一的 checkpoint,默认为 `True`。 - `autotuner_benchmark`: 是否启用 autotuner 基准测试,默认为 `False`。 - `benchmark`: 是否开启基准测试,默认为 `False`。 +- `optim`:默认为`adamw`,支持`adamw`, `adamw_mini`。 ### DPO 参数(DPOArguments) - `beta`: DPO 损失函数的 beta 参数,默认为 0.1。 - `simpo_gamma`: SimPO 损失函数的 gamma 参数,默认为 0.5。 diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md index 9d3d8ffcfb38..3d6f2184a0ff 100644 --- a/llm/docs/finetune.md +++ b/llm/docs/finetune.md @@ -184,6 +184,7 @@ python merge_lora_params.py \ - `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个 pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。 - `sharding_parallel_degree`: 表示分组参数切片的数据并行大小. 默认值1, 表示不启用分组参数切片的数据并行。 - `sharding`:是否使用 Paddle 的 Sharding 数据并行功能,用户的参数。支持 sharding `stage1`, `stage2` or `stage3`。其中`stage2``stage3`可以和`offload`组合使用。 +- `optim`:默认为`adamw`,支持`adamw`, `adamw_mini`。 diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 57c655736f25..59d74011e717 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1915,6 +1915,11 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: optimizer_cls = AdamW optimizer_kwargs.update(adam_kwargs) + elif args.optim == OptimizerNames.ADAMW_MINI: + from ..utils import AdamWMini + + optimizer_cls = AdamWMini + optimizer_kwargs.update(adam_kwargs) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 33ded2ce5bf6..e04f330c6050 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -317,6 +317,7 @@ class OptimizerNames(ExplicitEnum): ADAMW = "adamw" ADAFACTOR = "adafactor" + ADAMW_MINI = "adamw_mini" class ShardingOption(ExplicitEnum): diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 5d1dad82a831..6f9f501cdc8c 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1018,6 +1018,8 @@ def __post_init__(self): raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") self.optim = OptimizerNames(self.optim) + if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1: + raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") self.use_hybrid_parallel = False diff --git a/paddlenlp/utils/__init__.py b/paddlenlp/utils/__init__.py index a8c4dc487a0e..3b5950b0d701 100644 --- a/paddlenlp/utils/__init__.py +++ b/paddlenlp/utils/__init__.py @@ -21,6 +21,7 @@ from .import_utils import * from .infohub import infohub from .initializer import to +from .optimizer import * from .serialization import load_torch # hack impl for EagerParamBase to function diff --git a/paddlenlp/utils/optimizer.py b/paddlenlp/utils/optimizer.py new file mode 100644 index 000000000000..0b2904eb9e53 --- /dev/null +++ b/paddlenlp/utils/optimizer.py @@ -0,0 +1,151 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import pir +from paddle.base import core, framework +from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode +from paddle.base.libpaddle import DataType +from paddle.optimizer.adamw import AdamW +from paddle.pir import Value + + +class AdamWMini(AdamW): + def _add_moments_pows(self, p): + acc_dtype = p.dtype + if self._is_dtype_fp16_or_bf16(acc_dtype): + acc_dtype = DataType.FLOAT32 if in_pir_mode() else paddle.float32 + + self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) + # change moment2 + self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, shape=[1]) + try: + type = core.VarDesc.VarType.DENSE_TENSOR + except: + type = core.VarDesc.VarType.LOD_TENSOR + self._add_accumulator( + name=self._beta1_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.9 if isinstance(self._beta1, (Variable, Value)) else self._beta1, + shape=[1], + type=type, + device="cpu", + ) + self._add_accumulator( + name=self._beta2_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.999 if isinstance(self._beta2, (Variable, Value)) else self._beta2, + shape=[1], + type=type, + device="cpu", + ) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, (framework.Block, pir.Block)) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + param = param_and_grad[0] + + # Whether we should do weight decay for the parameter. + with_decay = True + if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name): + with_decay = False + + moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0]) + moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0]) + beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0]) + beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0]) + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) + master_weight = self._master_weights[param_and_grad[0].name] if find_master else None + lr = self._create_param_lr(param_and_grad) + # create the adamw optimize op + if in_dynamic_or_pir_mode(): + lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) + + _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0) + _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0) + + found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None + self.adamw_python( + param_and_grad[0], + param_and_grad[1], + lr, + moment1, + moment2, + beta1_pow_acc, + beta2_pow_acc, + master_weight, + found_inf, + _beta1, + _beta2, + self._epsilon, + lr_ratio_, + self._weight_decay, + with_decay, + find_master, + ) + return None + else: + raise NotImplementedError("Not implemented yet.") + + def adamw_python( + self, + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_weight, + skip_update, + beta1, + beta2, + epsilon, + lr_ratio, + coeff, + with_decay, + multi_precision, + ): + if skip_update: + return + if not with_decay: + coeff = 0.0 + if not multi_precision: + master_weight = None + lr = learning_rate * lr_ratio + if master_weight is not None: + p = master_weight + else: + p = param + p *= 1.0 - lr * coeff + mom1 = moment1 + mom2 = moment2 + + mom1 = beta1 * mom1 + (1.0 - beta1) * grad + mom2 = beta2 * mom2 + (1.0 - beta2) * (grad * grad).mean() + denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon + p += (moment1 / denom) * (-(lr / (1.0 - beta1_pow))) + if master_weight is not None: + master_weight[:] = p + param[:] = p.astype(param.dtype) + else: + param[:] = p + moment1[:] = mom1 + moment2[:] = mom2 + beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:] + # 看看怎么更新 + return diff --git a/tests/fixtures/llm/adamw_mini.yaml b/tests/fixtures/llm/adamw_mini.yaml new file mode 100644 index 000000000000..6dc6e9b865b9 --- /dev/null +++ b/tests/fixtures/llm/adamw_mini.yaml @@ -0,0 +1,35 @@ +finetune: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-05 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + use_flash_attention: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + refined_recompute: "flash_attn:-1" + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + ignore_save_lr_and_optim: 1 + optim: "adamw_mini" + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama diff --git a/tests/llm/test_adamw_mini.py b/tests/llm/test_adamw_mini.py new file mode 100644 index 000000000000..383d82407a06 --- /dev/null +++ b/tests/llm/test_adamw_mini.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import sys +import unittest + +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ], +) +class FinetuneTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/adamw_mini.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + sys.path.insert(0, self.model_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + + def test_finetune(self): + finetune_config = load_test_config(self.config_path, "finetune", self.model_dir) + + finetune_config["dataset_name_or_path"] = self.data_dir + finetune_config["output_dir"] = self.output_dir + + with argv_context_guard(finetune_config): + from run_finetune import main + + main() From 5a94e04521444e61fb184ad14f08f7a3d6d2d81f Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Mon, 16 Dec 2024 19:16:02 +0800 Subject: [PATCH 04/11] [Release] Update version for beta3 (#9553) * Update README.md --- README.md | 2 +- llm/docs/finetune.md | 2 +- llm/run_finetune.py | 11 +++++++++++ paddlenlp/__init__.py | 2 +- setup.py | 2 +- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c97aa56fc9b6..990d6dac5b25 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ ### pip 安装 ```shell -pip install --upgrade paddlenlp==3.0.0b2 +pip install --upgrade paddlenlp==3.0.0b3 ``` 或者可通过以下命令安装最新 develop 分支代码: diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md index 3d6f2184a0ff..233213a9b73b 100644 --- a/llm/docs/finetune.md +++ b/llm/docs/finetune.md @@ -36,7 +36,7 @@ ### 3.1 环境准备 - PaddlePaddle 3.0-beta -- PaddleNLP 3.0.0b2 +- PaddleNLP 3.0.0b3 - PaddleSlim develop git clone 代码到本地,即可开始。 diff --git a/llm/run_finetune.py b/llm/run_finetune.py index e18f25bb6cc2..a99d3bcbc224 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -79,7 +79,18 @@ flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe] +def paddlenlp_verison_check(): + import paddlenlp + from paddlenlp.utils.tools import compare_version + + if not compare_version(paddlenlp.__version__, "3.0.0.b2"): + raise ValueError( + "This scripts require paddlenlp >= 3.0.0b3, please reinstall: pip install paddlenlp >= 3.0.0b3 " + ) + + def main(): + paddlenlp_verison_check() parser = PdArgumentParser((GenerateArgument, ModelConfig, ReftArgument, DataConfig, SFTConfig)) if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): gen_args, model_args, reft_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py index 9f6d1e5953af..39ec1b3f4a2e 100644 --- a/paddlenlp/__init__.py +++ b/paddlenlp/__init__.py @@ -20,7 +20,7 @@ # this version is used for develop and test. # release version will be added fixed version by setup.py. -__version__ = "3.0.0b2.post" +__version__ = "3.0.0b3.post" if os.getenv(PADDLENLP_STABLE_VERSION): __version__ = __version__.replace(".post", "") else: diff --git a/setup.py b/setup.py index 131f23b9da77..3798dee63c2e 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ def show(): # only use this file to contral the version -__version__ = "3.0.0b2.post" +__version__ = "3.0.0b3.post" if os.getenv(PADDLENLP_STABLE_VERSION): __version__ = __version__.replace(".post", "") else: From dc0ca03b525028b47b54fe80bcf24f501574bc76 Mon Sep 17 00:00:00 2001 From: Weiguo Zhu Date: Mon, 16 Dec 2024 19:30:44 +0800 Subject: [PATCH 05/11] [LLM DOCs] Add deepseek models (#9643) * update --- README.md | 52 +++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 990d6dac5b25..c18c6926d25b 100644 --- a/README.md +++ b/README.md @@ -71,31 +71,33 @@ * 模型参数已支持 LLaMA 系列、Baichuan 系列、Bloom 系列、ChatGLM 系列、Gemma 系列、Mistral 系列、OPT 系列和 Qwen 系列,详细列表👉【LLM】模型参数支持列表如下: -| 模型系列 | 模型名称 | -|:----------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b | -| [LLama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat | -| [LLama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct | -| [LLama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B | -| [LLama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B | -| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat | -| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat | -| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m | -| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 | -| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b | -| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b | -| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it | -| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 | -| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 | -| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, opt-iml-max-1.3b | -| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat, | -| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat | -| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct | -| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B | -| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct | -| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B | -| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct | -| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B | +| 模型系列 | 模型名称 | +|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | facebook/llama-7b, facebook/llama-13b, facebook/llama-30b, facebook/llama-65b | +| [Llama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat | +| [Llama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct | +| [Llama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B | +| [Llama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B | +| [Llama3.3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.3-70B-Instruct | +| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat | +| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat | +| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m | +| [ChatGLM](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm/) | THUDM/chatglm-6b, THUDM/chatglm-6b-v1.1 | +| [ChatGLM2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm2-6b | +| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b | +| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct | +| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it | +| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 | +| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 | +| [OPT](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/opt) | facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, facebook/opt-iml-1.3b, opt-iml-max-1.3b | +| [Qwen](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | qwen/qwen-7b, qwen/qwen-7b-chat, qwen/qwen-14b, qwen/qwen-14b-chat, qwen/qwen-72b, qwen/qwen-72b-chat, | +| [Qwen1.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat, Qwen/Qwen1.5-72B, Qwen/Qwen1.5-72B-Chat, Qwen/Qwen1.5-110B, Qwen/Qwen1.5-110B-Chat, Qwen/Qwen1.5-MoE-A2.7B, Qwen/Qwen1.5-MoE-A2.7B-Chat | +| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct | +| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B | +| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct | +| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B | +| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct | +| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B | * 4D 并行和算子优化已支持 LLaMA 系列、Baichuan 系列、Bloom 系列、ChatGLM 系列、Gemma 系列、Mistral 系列、OPT 系列和 Qwen 系列,【LLM】模型4D 并行和算子支持列表如下: From 9eb3cfeafc03e291f5ab1de3059df25dd6a869f5 Mon Sep 17 00:00:00 2001 From: Weiguo Zhu Date: Tue, 17 Dec 2024 14:04:25 +0800 Subject: [PATCH 06/11] [Tokenizer] Fix tokenizer of llama3.3 (#9641) * fix tokenizer of llama3 and add test case * fix paddle.where --- paddlenlp/transformers/llama/modeling.py | 3 ++- paddlenlp/transformers/llama/tokenizer.py | 6 +++-- tests/transformers/llama/test_tokenizer.py | 26 ++++++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 8bf0d5938902..80781225a36c 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1601,7 +1601,8 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values expanded_attn_mask = expanded_attn_mask.astype(dtype) expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) else: - expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min) + expanded_attn_mask = expanded_attn_mask.astype(dtype) return expanded_attn_mask @paddle.jit.not_to_static diff --git a/paddlenlp/transformers/llama/tokenizer.py b/paddlenlp/transformers/llama/tokenizer.py index 8260e2b1239f..cb9ccc423b4f 100644 --- a/paddlenlp/transformers/llama/tokenizer.py +++ b/paddlenlp/transformers/llama/tokenizer.py @@ -340,9 +340,11 @@ def __init__( self.eos_token = ENDOFTEXT self.bos_token_id = self.bod_id self.eos_token_id = self.eod_id - self.pad_token = self.convert_ids_to_tokens(self.eos_token_id) + if "pad_token" not in kwargs: + self.pad_token = self.convert_ids_to_tokens(self.eos_token_id) + kwargs["pad_token"] = self.pad_token - super().__init__(pad_token=self.pad_token, **kwargs) + super().__init__(**kwargs) def __len__(self) -> int: return self.tokenizer.n_vocab diff --git a/tests/transformers/llama/test_tokenizer.py b/tests/transformers/llama/test_tokenizer.py index 940548a7a950..fa1ee6eaf84e 100644 --- a/tests/transformers/llama/test_tokenizer.py +++ b/tests/transformers/llama/test_tokenizer.py @@ -17,6 +17,8 @@ import tempfile import unittest +from parameterized import parameterized_class + from paddlenlp.transformers.auto.tokenizer import AutoTokenizer from paddlenlp.transformers.llama.tokenizer import LlamaTokenizer from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer @@ -213,6 +215,30 @@ def test_pretrained_model_lists(self): self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_resource_files_map.values())[0]), 1) +@parameterized_class( + ["model_name_or_path"], + [ + ["facebook/llama-7b"], + ["meta-llama/Meta-Llama-3.1-8B"], + ["meta-llama/Llama-3.2-1B"], + ["meta-llama/Llama-3.3-70B-Instruct"], + ], +) +class LlamaTokenizationLoadTest(unittest.TestCase): + model_name_or_path: str = None + + def get_tokenizer(self, **kwargs) -> PretrainedTokenizer: + tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, **kwargs) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + + def test_load_tokenizer(self): + tokenizer = self.get_tokenizer() + text = "lower newer" + tokenizer.tokenize(text, add_prefix_space=True) + + class TikTokenIntegrationTests(unittest.TestCase): """ A class that regroups important test to make sure that we properly handle the special tokens. From bb9793130eb2c88c2791f10924a79c7c1ec81a0c Mon Sep 17 00:00:00 2001 From: waliwali777 Date: Tue, 17 Dec 2024 15:55:45 +0800 Subject: [PATCH 07/11] add test (#9621) --- scripts/distribute/ci_case_auto.sh | 152 +++++++++++++++-------------- 1 file changed, 78 insertions(+), 74 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 2dc44fb57cec..634f226ff449 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -667,80 +667,84 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() { case_log_dir="output/$task_name""_log" for to_static in "0" "1"; do - rm -rf $case_out_dir - rm -rf $case_log_dir - python -u -m paddle.distributed.launch \ - --gpus "0,1,2,3" \ - --log_dir $case_log_dir \ - run_pretrain_auto.py \ - --model_type "llama" \ - --model_name_or_path "facebook/llama-7b" \ - --tokenizer_name_or_path "facebook/llama-7b" \ - --input_dir "./data" \ - --output_dir $case_out_dir \ - --split 949,50,1 \ - --weight_decay 0.01 \ - --warmup_ratio 0.01 \ - --max_grad_norm 0.0 \ - --learning_rate 3e-05 \ - --min_learning_rate 3e-06 \ - --max_steps 10 \ - --logging_steps 10 \ - --eval_steps 1000 \ - --save_steps 50000 \ - --continue_training 0 \ - --do_train true \ - --do_eval false \ - --do_predict false \ - --disable_tqdm true \ - --skip_profile_timer true \ - --save_total_limit 2 \ - --device gpu \ - --disable_tqdm true \ - --dataloader_num_workers 1 \ - --enable_auto_parallel 1 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --per_device_eval_batch_size 2 \ - --recompute false \ - --bf16 1\ - --fp16_opt_level "O2" \ - --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \ - --amp_custom_white_list "lookup_table" "lookup_table_v2" \ - --amp_master_grad 1 \ - --fuse_attention_ffn false \ - --fuse_attention_qkv false \ - --fuse_sequence_parallel_allreduce false \ - --use_flash_attention 0 \ - --use_fused_rope false \ - --use_fused_rms_norm 0 \ - --max_seq_length 4096 \ - --sep_parallel_degree 1 \ - --sequence_parallel true \ - --pipeline_parallel_degree 1 \ - --sharding_parallel_degree 1 \ - --tensor_parallel_degree 2 \ - --virtual_pp_degree 1 \ - --sharding "" \ - --to_static ${to_static} \ - --num_hidden_layers 4 \ - >>${log_path}/$FUNCNAME 2>&1 - loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` - loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` - ips=-1 - mem=-1 - echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem" - loss_base=9.16783295 - loss_md5_base=8ea72495fba4e1b9ba004b4431e27218 - if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ];then - loss_base=9.37966919 - elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ];then - loss_base=9.38012543 - fi - ips_base=-1 - mem_base=-1 - check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} - # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5} + for use_recompute in "1" "0"; do + rm -rf $case_out_dir + rm -rf $case_log_dir + python -u -m paddle.distributed.launch \ + --gpus "0,1,2,3" \ + --log_dir $case_log_dir \ + run_pretrain_auto.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $case_out_dir \ + --split 949,50,1 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 0.0 \ + --learning_rate 3e-05 \ + --min_learning_rate 3e-06 \ + --max_steps 10 \ + --logging_steps 10 \ + --eval_steps 1000 \ + --save_steps 50000 \ + --continue_training 0 \ + --do_train true \ + --do_eval false \ + --do_predict false \ + --disable_tqdm true \ + --skip_profile_timer true \ + --save_total_limit 2 \ + --device gpu \ + --disable_tqdm true \ + --dataloader_num_workers 1 \ + --enable_auto_parallel 1 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --per_device_eval_batch_size 2 \ + --recompute ${use_recompute} \ + --bf16 1\ + --fp16_opt_level "O2" \ + --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" \ + --amp_master_grad 1 \ + --fuse_attention_ffn false \ + --fuse_attention_qkv false \ + --fuse_sequence_parallel_allreduce false \ + --use_flash_attention 0 \ + --use_fused_rope false \ + --use_fused_rms_norm 0 \ + --max_seq_length 4096 \ + --sep_parallel_degree 1 \ + --sequence_parallel true \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --tensor_parallel_degree 2 \ + --virtual_pp_degree 1 \ + --sharding "" \ + --to_static ${to_static} \ + --num_hidden_layers 4 \ + >>${log_path}/$FUNCNAME 2>&1 + loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'` + mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'` + echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem" + loss_base=9.16783295 + loss_md5_base=8ea72495fba4e1b9ba004b4431e27218 + if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ];then + loss_base=9.37966919 + elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ];then + loss_base=9.38012543 + fi + ips=-1 + mem=-1 + ips_base=-1 + mem_base=-1 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5} + done done echo "=========== $FUNCNAME run end ===========" } From 12fba788c0fe39ed9414b29c7e001651bc47dab5 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Tue, 17 Dec 2024 15:56:59 +0800 Subject: [PATCH 08/11] Update README.md for 3.0 beta3(#9644) --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c18c6926d25b..d2e630bf81f2 100644 --- a/README.md +++ b/README.md @@ -31,15 +31,19 @@ PaddlePaddle%2FPaddleNLP | Trendshift ## News 📢 +* **2024.12.16 [PaddleNLP v3.0 Beta3](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta3)**:大模型功能全新升级,新增了Llama-3.2、DeepSeekV2模型,升级了TokenizerFast,快速分词,重构了SFTTrainer,一键开启SFT训练。此外,PaddleNLP还支持了优化器状态的卸载和重载功能,实现了精细化的重新计算,训练性能提升7%。在Unified Checkpoint方面,进一步优化了异步保存逻辑,新增Checkpoint压缩功能,可节省78.5%存储空间。 +最后,在大模型推理方面,升级Append Attention,支持了FP8量化,支持投机解码。 * **2024.12.13 📚《飞桨大模型套件 Unified Checkpoint 技术》**,加速模型存储95%,节省空间78%。支持全分布式策略调整自适应转换,提升模型训练的灵活性与可扩展性。训练-压缩-推理统一存储协议,无需手动转换提升全流程体验。Checkpoint 无损压缩结合异步保存,实现秒级存储并降低模型存储成本。适用于智能制造、指挥交通、医疗健康、金融服务等产业实际场景。12月24日(周二)19:00直播为您详细解读该技术如何优化大模型训练流程。报名链接:https://www.wjx.top/vm/huZkHn9.aspx?udsid=787976 * **2024.11.28 📚《FlashRAG-Paddle | 基于 PaddleNLP 的高效开发与评测 RAG 框架》**,为文本更快更好构建准确嵌入表示、加速推理生成速度。PaddleNLP 支持超大 Batch 嵌入表示学习与多硬件高性能推理,涵盖 INT8/INT4量化技术及多种高效注意力机制优化与 TensorCore 深度优化。内置全环节算子融合技术,使得 FlashRAG 推理性能相比 transformers 动态图提升70%以上,结合检索增强知识输出结果更加准确,带来敏捷高效的使用体验。直播时间:12月3日(周二)19:00。报名链接:https://www.wjx.top/vm/eaBa1vA.aspx?udsid=682361 -* **2024.08.08 📚《飞桨产业级大语言模型开发利器 PaddleNLP 3.0 重磅发布》**,训压推全流程贯通,主流模型全覆盖。大模型自动并行,千亿模型训推全流程开箱即用。提供产业级高性能精调与对齐解决方案,压缩推理领先,多硬件适配。覆盖产业级智能助手、内容创作、知识问答、关键信息抽取等应用场景。直播时间:8月22日(周四)19:00。报名链接:https://www.wjx.top/vm/Y2f7FFY.aspx?udsid=143844 -
点击展开
+
点击展开
+ +* **2024.08.08 📚《飞桨产业级大语言模型开发利器 PaddleNLP 3.0 重磅发布》**,训压推全流程贯通,主流模型全覆盖。大模型自动并行,千亿模型训推全流程开箱即用。提供产业级高性能精调与对齐解决方案,压缩推理领先,多硬件适配。覆盖产业级智能助手、内容创作、知识问答、关键信息抽取等应用场景。直播时间:8月22日(周四)19:00。报名链接:https://www.wjx.top/vm/Y2f7FFY.aspx?udsid=143844 + * **2024.06.27 [PaddleNLP v3.0 Beta](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta0)**:拥抱大模型,体验全升级。统一大模型套件,实现国产计算芯片全流程接入;全面支持飞桨4D 并行配置、高效精调策略、高效对齐算法、高性能推理等大模型产业级应用流程;自研极致收敛的 RsLoRA+算法、自动扩缩容存储机制 Unified Checkpoint 和通用化支持的 FastFFN、FusedQKV 助力大模型训推;主流模型持续支持更新,提供高效解决方案。 * **2024.04.24 [PaddleNLP v2.8](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.8.0)**:自研极致收敛的 RsLoRA+算法,大幅提升 PEFT 训练收敛速度以及训练效果;引入高性能生成加速到 RLHF PPO 算法,打破 PPO 训练中生成速度瓶颈,PPO 训练性能大幅领先。通用化支持 FastFFN、FusedQKV 等多个大模型训练性能优化方式,大模型训练更快、更稳定。 From 4451c0587932d9b0ade7129210c39eb6c743a070 Mon Sep 17 00:00:00 2001 From: waliwali777 Date: Tue, 17 Dec 2024 15:57:42 +0800 Subject: [PATCH 09/11] Add replace_with_parallel_cross_entropy flag (#9579) * Add replace_with_parallel_cross_entropy flag * add a100 loss_base * fix --- paddlenlp/trainer/auto_training_args.py | 5 + paddlenlp/trainer/training_args.py | 3 + scripts/distribute/ci_case_auto.sh | 170 ++++++++++++++---------- tests/trainer/test_auto_argparser.py | 1 + 4 files changed, 106 insertions(+), 73 deletions(-) diff --git a/paddlenlp/trainer/auto_training_args.py b/paddlenlp/trainer/auto_training_args.py index ee0a5c6c503e..eaa394b1c4a2 100644 --- a/paddlenlp/trainer/auto_training_args.py +++ b/paddlenlp/trainer/auto_training_args.py @@ -14,6 +14,7 @@ from dataclasses import dataclass, field +from .trainer_utils import split_parallel_config from .training_args import TrainingArguments from .utils import add_start_docstrings @@ -68,3 +69,7 @@ def __post_init__(self): if self.fused_linear: fused_passes.enable = True fused_passes.fused_passes_list.append("fused_gemm_epilogue_pass") + + mp_configs = split_parallel_config(self.tensor_parallel_config) + if "replace_with_parallel_cross_entropy" in mp_configs: + self.strategy.mp_optimization.replace_with_parallel_cross_entropy = True diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 6f9f501cdc8c..80813843ecb7 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -266,6 +266,7 @@ class TrainingArguments: sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False. sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False. sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False. + replace_with_parallel_cross_entropy, it replaces 'cross_entropy_with_softmax' OP with 'c_softmax_with_cross_entropy' OP in PIR static graph, which can improve model parallel performance. pipeline_parallel_config (`str`, *optional*)( Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it. following config is support: @@ -681,6 +682,7 @@ class TrainingArguments: "sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n" "sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n" "sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.\n" + "replace_with_parallel_cross_entropy, it replaces 'cross_entropy_with_softmax' OP with 'c_softmax_with_cross_entropy' OP in PIR static graph, which can improve model parallel performance.\n" ) }, ) @@ -1567,6 +1569,7 @@ def is_segment_parallel_supported(): "enable_delay_scale_loss", # "enable_mp_skip_c_identity", # "enable_mp_fused_linear_param_grad_add", + "replace_with_parallel_cross_entropy", ]: raise ValueError( f"Found unknown tensor parallell config {x}, " diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 634f226ff449..19c1a40da3d0 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -767,80 +767,104 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() { auto_task_name="llama_pir_auto_fuse_ffn_attention_qkv_MP2" auto_case_out_dir="auto_output/$auto_task_name" auto_case_log_dir="auto_output/$auto_task_name""_log" - rm -rf $auto_case_out_dir - rm -rf $auto_case_log_dir - - python -u -m paddle.distributed.launch \ - --gpus "0,1" \ - --log_dir $auto_case_log_dir \ - run_pretrain_auto.py \ - --model_name_or_path "facebook/llama-7b" \ - --tokenizer_name_or_path "facebook/llama-7b" \ - --input_dir "./data" \ - --output_dir $auto_case_out_dir \ - --split 949,50,1 \ - --weight_decay 0.01 \ - --warmup_ratio 0.01 \ - --warmup_steps 30 \ - --max_grad_norm 0.0 \ - --learning_rate 3e-05 \ - --min_learning_rate 3e-06 \ - --max_steps 5 \ - --logging_steps 1 \ - --eval_steps 1000 \ - --save_steps 3 \ - --continue_training 0 \ - --do_train true \ - --do_eval false \ - --do_predict false \ - --disable_tqdm true \ - --skip_profile_timer true \ - --save_total_limit 2 \ - --device gpu \ - --disable_tqdm true \ - --dataloader_num_workers 1 \ - --distributed_dataloader 0 \ - --enable_auto_parallel 1 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --per_device_eval_batch_size 2 \ - --recompute false \ - --recompute_use_reentrant true \ - --recompute_granularity full \ - --pp_recompute_interval 0 \ - --bf16 0 \ - --fp16_opt_level "O2" \ - --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \ - --amp_custom_white_list "lookup_table" "lookup_table_v2" \ - --amp_master_grad false \ - --fuse_attention_ffn false \ - --fuse_attention_qkv false \ - --use_flash_attention false \ - --use_fused_rope true \ - --use_fused_rms_norm true \ - --max_seq_length 4096 \ - --sequence_parallel false \ - --pipeline_parallel_degree 1 \ - --sharding_parallel_degree 1 \ - --tensor_parallel_degree 2 \ - --virtual_pp_degree 1 \ - --pipeline_schedule_mode "VPP" \ - --sharding "" \ - --to_static 1 \ - --num_hidden_layers 2 \ - >>${log_path}/$FUNCNAME 2>&1 + + tp_configs=( + "--tensor_parallel_config replace_with_parallel_cross_entropy" + " " + ) + for tp_config in "${tp_configs[@]}"; do + rm -rf $auto_case_out_dir + rm -rf $auto_case_log_dir + python -u -m paddle.distributed.launch \ + --gpus "0,1" \ + --log_dir $auto_case_log_dir \ + run_pretrain_auto.py \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $auto_case_out_dir \ + --split 949,50,1 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --warmup_steps 30 \ + --max_grad_norm 0.0 \ + --learning_rate 3e-05 \ + --min_learning_rate 3e-06 \ + --max_steps 10 \ + --logging_steps 1 \ + --eval_steps 1000 \ + --save_steps 3 \ + --continue_training 0 \ + --do_train true \ + --do_eval false \ + --do_predict false \ + --disable_tqdm true \ + --skip_profile_timer true \ + --save_total_limit 2 \ + --device gpu \ + --disable_tqdm true \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --enable_auto_parallel 1 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --per_device_eval_batch_size 2 \ + --recompute false \ + --recompute_use_reentrant true \ + --recompute_granularity full \ + --pp_recompute_interval 0 \ + --bf16 0 \ + --fp16_opt_level "O2" \ + --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" \ + --amp_master_grad false \ + --fuse_attention_ffn false \ + --fuse_attention_qkv false \ + --use_flash_attention false \ + --use_fused_rope true \ + --use_fused_rms_norm true \ + --max_seq_length 4096 \ + --sequence_parallel false \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --tensor_parallel_degree 2 \ + ${tp_config} \ + --virtual_pp_degree 1 \ + --pipeline_schedule_mode "VPP" \ + --sharding "" \ + --to_static 1 \ + --num_hidden_layers 2 \ + >>${log_path}/$FUNCNAME 2>&1 - auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 5' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` - auto_ips=-1 - auto_mem=-1 - echo "auto result: step 5 loss=$auto_loss ips=$auto_ips mem=$auto_mem" - loss_base=10.21024895 - ips_base=-1 - mem_base=-1 - if [ $IS_A100 -ne 0 ];then - loss_base=10.27925682 - fi - check_result $FUNCNAME ${loss_base} ${auto_loss} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem} + auto_loss_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + loss_md5_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` + auto_ips_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'` + auto_mem_2=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 2' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'` + echo "auto result: step 2 loss=$auto_loss_2 ips=$auto_ips_2 mem=$auto_mem_2" + auto_loss_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + loss_md5_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` + auto_ips_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'` + auto_mem_10=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'` + echo "auto result: step 10 loss=$auto_loss_10 ips=$auto_ips_10 mem=$auto_mem_10" + if [[ $tp_config =~ "replace_with_parallel_cross_entropy" ]];then + # This optimization may result in a discrepancy in accuracy. + loss_base_2=10.53477287 + loss_base_10=9.4961338 + else + loss_base_2=10.53477192 + loss_base_10=9.4961338 + fi + auto_ips=-1 + auto_mem=-1 + ips_base=-1 + mem_base=-1 + if [ $IS_A100 -ne 0 ];then + loss_base_2=10.58283806 + loss_base_10=10.58283806 + fi + check_result $FUNCNAME ${loss_base_2} ${auto_loss_2} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem} + check_result $FUNCNAME ${loss_base_10} ${auto_loss_10} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem} + done export FLAGS_enable_fused_ffn_qkv_pass=0 echo "=========== $FUNCNAME run end ===========" } diff --git a/tests/trainer/test_auto_argparser.py b/tests/trainer/test_auto_argparser.py index 9d5b311c41e9..d10d8a786d70 100644 --- a/tests/trainer/test_auto_argparser.py +++ b/tests/trainer/test_auto_argparser.py @@ -66,6 +66,7 @@ class AutoArgparserTest(unittest.TestCase): "num_cycles": 0.5, "num_train_epochs": 3.0, "output_dir": "./checkpoints/llama2_pretrain_ckpts", + "tensor_parallel_config": "replace_with_parallel_cross_entropy", } def test_parse_cmd_lines(self): From da7a7d296f2f67ce2ccf77d2959a0b717d421689 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Tue, 17 Dec 2024 16:07:18 +0800 Subject: [PATCH 10/11] [AutoParallel] change loss_base after dropout support spmd (#9647) * [AutoParallel] change loss_base after dropout support spmd * [AutoParallel] change loss_base after dropout support spmd --- scripts/distribute/ci_case_auto.sh | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 19c1a40da3d0..65f7c439299f 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1778,11 +1778,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() { ips=-1 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5" - loss_base=10.59368134 + loss_base=10.59486389 # output of dropout is different after supporting spmd ips_base=-1 mem_base=-1 if [ $IS_A100 -ne 0 ];then - loss_base=10.60190201 + loss_base=10.60063553 # after add dropout spmd fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" @@ -1850,11 +1850,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() { ips=-1 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5" - loss_base=10.5913763 + loss_base=10.58862114 # output of dropout is different after supporting spmd ips_base=-1 mem_base=-1 if [ $IS_A100 -ne 0 ];then - loss_base=10.5915575 + loss_base=10.59354877 # after add dropout spmd fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" @@ -1923,11 +1923,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5" # loss_base=10.59993172 # note: need to debug - loss_base=10.58103752 + loss_base=10.58122158 # output of dropout is different after supporting spmd ips_base=-1 mem_base=-1 if [ $IS_A100 -ne 0 ];then - loss_base=10.58719826 + loss_base=10.58605194 # after add dropout spmd fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" @@ -1996,12 +1996,11 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5" # loss_base=10.58456802 # note: need to debug - loss_base=10.58146572 + loss_base=10.58163357 ips_base=-1 mem_base=-1 if [ $IS_A100 -ne 0 ];then - # loss_base=10.58141422 # note: need to debug - loss_base=10.58743668 + loss_base=10.58635044 # after add dropout spmd fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" From 2231feb06840b64afa10ae76f131e87287cefa1a Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 17 Dec 2024 20:51:28 +0800 Subject: [PATCH 11/11] [Embedding] Add embedding training (#9508) * add Qwen2SentenceEmbedding * add embedding trainer --------- Co-authored-by: DrownFish19 --- llm/config/qwen/emb_argument.json | 36 +++ llm/run_embedding.py | 288 +++++++++++++++++++++++ llm/utils/argument.py | 52 ++++ paddlenlp/data/data_collator.py | 124 ++++++++++ paddlenlp/datasets/__init__.py | 1 + paddlenlp/datasets/embedding_dataset.py | 252 ++++++++++++++++++++ paddlenlp/transformers/qwen2/modeling.py | 83 ++++++- 7 files changed, 835 insertions(+), 1 deletion(-) create mode 100644 llm/config/qwen/emb_argument.json create mode 100644 llm/run_embedding.py create mode 100644 paddlenlp/datasets/embedding_dataset.py diff --git a/llm/config/qwen/emb_argument.json b/llm/config/qwen/emb_argument.json new file mode 100644 index 000000000000..d8c6aeeb7f6e --- /dev/null +++ b/llm/config/qwen/emb_argument.json @@ -0,0 +1,36 @@ +{ + "model_name_or_path": "Qwen/Qwen2-0.5B", + "dataset_name_or_path": "./dureader_data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 1, + "eval_accumulation_steps": 1, + "max_steps": 2000, + "learning_rate": 3e-5, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "epoch", + "max_query_len": 512, + "max_passage_len": 512, + "group_size": 4, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": false, + "disable_tqdm": true, + "load_best_model_at_end": false, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage1", + "zero_padding": false, + "unified_checkpoint": true, + "use_flash_attention": true, + "amp_custom_black_list": "elementwise_div", + "release_grads": true +} diff --git a/llm/run_embedding.py b/llm/run_embedding.py new file mode 100644 index 000000000000..e598f24839cf --- /dev/null +++ b/llm/run_embedding.py @@ -0,0 +1,288 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import inspect +import os +import sys + +import paddle +from utils.argument import EmbeddingArgument + +from paddlenlp.data import DataCollatorForEmbedding +from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed +from paddlenlp.trainer.trainer_callback import TrainerState +from paddlenlp.transformers import ( + AutoConfig, + AutoTokenizer, + Qwen2Config, + Qwen2SentenceEmbedding, +) +from paddlenlp.transformers.configuration_utils import LlmMetaConfig +from paddlenlp.transformers.refined_recompute import update_refined_recompute +from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig +from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template +from paddlenlp.utils.log import logger + +# Fine-tune Environment Variables to support sharding stage1 overlap optimization. +os.environ["USE_CASUAL_MASK"] = "False" + + +def main(): + parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument)) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses() + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Setup GPU & distributed training + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + if training_args.pipeline_parallel_degree > 1: + raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Load model + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + elif training_args.bf16: + dtype = "bfloat16" + else: + raise ValueError("Please specific dtype: --fp16 or --bf16") + else: + dtype = "float32" + + model_config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + dtype=dtype, + from_aistudio=model_args.from_aistudio, + ) + assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported" + + LlmMetaConfig.set_llm_config(model_config, training_args) + model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute) + model_config.use_fast_layer_norm = model_args.use_fast_layer_norm + + # Config for model using dropout, such as GPT. + if hasattr(model_config, "hidden_dropout_prob"): + model_config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(model_config, "attention_probs_dropout_prob"): + model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if hasattr(model_config, "ignore_index"): + model_config.ignore_index = -100 + + if model_args.fuse_attention_qkv is not None: + model_config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + model_config.fuse_attention_ffn = model_args.fuse_attention_ffn + + model_config.seq_length = data_args.max_length + model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device + logger.info(f"Final model config: {model_config}") + + model_class = Qwen2SentenceEmbedding + + if model_args.continue_training and not training_args.autotuner_benchmark: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=model_config, + from_aistudio=model_args.from_aistudio, + ) + else: + model = model_class.from_config(model_config, dtype=dtype) + + if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + logger.warning("`flash_mask` must use with zero padding and flash attention.") + data_args.zero_padding = True + model.config.use_flash_attention = True + + # Load tokenizer & dataset + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio) + + # init chat_template for tokenizer + init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template) + + # if using chat_template, data_args.eval_with_do_generation must be false + if tokenizer.chat_template is not None: + data_args.eval_with_do_generation = False + + if training_args.do_eval: + logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.") + training_args.do_eval = False + training_args.evaluation_strategy = "no" + + if data_args.dataset_name_or_path is None: + raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})") + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists( + os.path.join(data_args.dataset_name_or_path, "dev.json") + ): + if training_args.do_train: + train_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), + lazy=data_args.lazy, + )[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"), + lazy=data_args.lazy, + )[0] + else: + dev_ds = None + + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists( + os.path.join(data_args.dataset_name_or_path, "dev") + ): + import glob + + if training_args.do_train: + train_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), + lazy=data_args.lazy, + )[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")), + lazy=data_args.lazy, + )[0] + else: + dev_ds = None + + else: + if training_args.do_train: + train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0] + else: + dev_ds = None + + # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. + if training_args.resume_from_checkpoint is not None and data_args.lazy: + logger.info( + f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True." + ) + training_args.ignore_data_skip = True + state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json")) + if state.trial_params is not None and "zero_padding_global_step" in state.trial_params: + consumed_samples = state.trial_params["zero_padding_global_step"] + else: + consumed_samples = ( + state.global_step + * training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.dataset_world_size + ) + logger.info( + f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'." + ) + train_ds = train_ds.skip(consumed_samples) + + if train_ds is not None: + train_ds = EmbeddingIterableDataset( + train_ds, + tokenizer, + max_query_len=embedding_args.max_query_len, + max_passage_len=embedding_args.max_passage_len, + group_size=embedding_args.group_size, + query_template=embedding_args.query_template, + passage_template=embedding_args.passage_template, + ) + + if dev_ds is not None: + dev_ds = EmbeddingIterableDataset( + dev_ds, + tokenizer, + max_query_len=embedding_args.max_query_len, + max_passage_len=embedding_args.max_passage_len, + group_size=embedding_args.group_size, + query_template=embedding_args.query_template, + passage_template=embedding_args.passage_template, + ) + + # Create trainer + if data_args.pad_to_max_length: + padding = "max_length" + else: + padding = True + + data_collator_fn = DataCollatorForEmbedding( + tokenizer=tokenizer, + max_query_len=embedding_args.max_query_len, + padding=padding, + max_passage_len=embedding_args.max_passage_len, + return_tensors="np", + return_attention_mask=not model_args.flash_mask, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ) + trainer = EmbeddingTrainer( + model=model, + model_args=embedding_args, + args=training_args, + train_dataset=train_ds, + eval_dataset=dev_ds, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + data_collator=data_collator_fn, + ) + trainable_parameters = [p for p in model.parameters() if not p.stop_gradient] + trainer.set_optimizer_grouped_parameters(trainable_parameters) + + # Train + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation dev set + if training_args.do_eval: + logger.info("*** Evaluate result after train ***") + eval_result = trainer.evaluate(dev_ds) + trainer.log_metrics("eval", eval_result) + + +if __name__ == "__main__": + main() diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 812293f1ab8f..99df142e826e 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field +from typing import List, Optional @dataclass @@ -36,3 +37,54 @@ class GenerateArgument: top_p: float = field( default=1.0, metadata={"help": "The cumulative probability for top-p-filtering in the sampling strategy."} ) + + +@dataclass +class EmbeddingArgument: + max_query_len: int = field( + default=1, + metadata={ + "help": "The number of highest probability tokens to keep for top-k-filtering in the sampling strategy" + }, + ) + max_passage_len: int = field( + default=1.0, metadata={"help": "The cumulative probability for top-p-filtering in the sampling strategy."} + ) + group_size: int = field( + default=8, + metadata={ + "help": ( + "Number of total positive and negative samples associated with " "each query for embedding training." + ) + }, + ) + query_template: str = field( + default="Query: {text}\nUse one word to summarize the query's relevant information. The word is: \"", + metadata={ + "help": ( + "Query template. Ensure the template includes the placeholder " + "'{text}' to insert the actual query text." + ) + }, + ) + passage_template: str = field( + default="Text: {text}\nUse one word to summarize the text's content. The word is: \"", + metadata={ + "help": ( + "Passage template. Ensure the template includes the placeholder " + "'{text}' to insert the actual passage text." + ) + }, + ) + embedding_temperature: float = field( + default=0.02, + metadata={"help": "The temperature used in embedding learning."}, + ) + embedding_negatives_cross_device: bool = field( + default=True, + metadata={"help": "Whether to share the negatives across all GPUs."}, + ) + embedding_matryoshka_dims: Optional[List[int]] = field( + default=None, + metadata={"help": "The dims for matryoshka training."}, + ) diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py index 78d3b3517ca0..d06953b4ee7a 100644 --- a/paddlenlp/data/data_collator.py +++ b/paddlenlp/data/data_collator.py @@ -39,6 +39,7 @@ "DataCollatorForSeq2Seq", "DataCollatorForLanguageModeling", "DataCollatorForWholeWordMask", + "DataCollatorForEmbedding", ] InputDataClass = NewType("InputDataClass", Any) @@ -417,6 +418,129 @@ def __call__(self, features, return_tensors=None): return batch +@dataclass +class DataCollatorForEmbedding: + tokenizer: PretrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pd" + return_attention_mask: Optional[bool] = None + max_label_length: Optional[int] = None + + max_query_len: int = 512 + max_passage_len: int = 512 + + def __call__(self, batch, return_tensors=None) -> Any: + """Convert batch data into tensor.""" + input_keys = ["input_ids", "position_ids"] + + attn_key = "attention_mask" + input_keys.append(attn_key) + + # Initialize query and passage lists + queries = {key: [] for key in input_keys} + passages = {key: [] for key in input_keys} + + batch_query_embedding_indices = [] + batch_passage_embedding_indices = [] + + global_passage_idx = 0 + + # Process each batch sequence + for idx, batch_sequence in enumerate(batch): + query_data = [pair.query for pair in batch_sequence] + padded_query_token_ids, padded_query_position_ids, query_token_ids = self.process_data( + query_data, self.tokenizer.pad_token_id, self.max_query_len + ) + + queries["input_ids"].append(padded_query_token_ids) + queries["position_ids"].append(padded_query_position_ids) + batch_query_embedding_indices.append([idx, len(query_token_ids[0]) - 1]) + + queries[attn_key].append(self.gen_self_attn_mask(query_token_ids, self.max_query_len)) + + for pair in batch_sequence: + for passage in pair.passages: + passage_data = [passage] + padded_passage_token_ids, padded_passage_position_ids, passage_token_ids = self.process_data( + passage_data, self.tokenizer.pad_token_id, self.max_passage_len + ) + + passages["input_ids"].append(padded_passage_token_ids) + passages["position_ids"].append(padded_passage_position_ids) + batch_passage_embedding_indices.append([global_passage_idx, len(passage_token_ids[0]) - 1]) + + passages[attn_key].append(self.gen_self_attn_mask(passage_token_ids, self.max_passage_len)) + global_passage_idx += 1 + + for data in (queries, passages): + for k, v in data.items(): + data[k] = paddle.to_tensor(np.concatenate(v)) + + queries["embedding_indices"] = paddle.to_tensor(np.array(batch_query_embedding_indices, dtype="int32")) + passages["embedding_indices"] = paddle.to_tensor(np.array(batch_passage_embedding_indices, dtype="int32")) + + return { + "query": queries, + "passages": passages, + } + + def process_data(self, data, pad_idx, max_len): + """padding token_ids & position_ids.""" + token_ids = [sum((item.token_ids for item in data), [])] + position_ids = [sum((item.position_ids for item in data), [])] + padded_token_ids = self.pad_batch_data(token_ids, pad_id=pad_idx, max_seq_len=max_len) + padded_position_ids = self.pad_batch_data(position_ids, pad_id=0, max_seq_len=max_len) + return padded_token_ids, padded_position_ids, token_ids + + @staticmethod + def pad_batch_data(insts, pad_id=0, max_seq_len=None, return_seq_len=False, pad_style="right"): + """Pad sequences to the max sequence length in batch.""" + max_len = max_seq_len if max_seq_len is not None else max(map(len, insts)) + if pad_style == "left": + inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]) + else: + inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]) + + if return_seq_len: + seq_len = np.array([len(inst) for inst in insts]) + return inst_data.astype("int64").reshape([-1, max_len]), seq_len + else: + return inst_data.astype("int64").reshape([-1, max_len]) + + @staticmethod + def gen_self_attn_mask(batch_token_ids: List[List[int]], max_seq_len: int): + """Generate self attention mask for multiple sub-sequence.""" + input_mask_data = np.zeros((1, 1, max_seq_len, max_seq_len), dtype="float32") + offset = 0 + for index, token_ids in enumerate(batch_token_ids): + cur_len = len(token_ids) + b = np.tril(np.ones([cur_len, cur_len]), 0) + input_mask_data[0, 0, offset : offset + cur_len, offset : offset + cur_len] = b + offset += cur_len + return input_mask_data + + @staticmethod + def gen_attn_mask_start_row_indices(batch_token_ids: List[List[int]], max_seq_len: int, sliding_window: int): + """Generate attn_mask_start_row_indices for flash attention.""" + offset = 0 + attn_mask_start_row_indices = [] + for token_ids in batch_token_ids: + cur_len = len(token_ids) + if sliding_window > 0: + for i in range(cur_len): + attn_mask_start_row_indices.append(offset + min(cur_len, i + sliding_window)) + else: + attn_mask_start_row_indices.extend([offset + cur_len] * cur_len) + offset += cur_len + if offset < max_seq_len: + attn_mask_start_row_indices.extend(list(range(offset + 1, max_seq_len + 1))) + + return np.array(attn_mask_start_row_indices, dtype=np.int32)[None, None] + + def _paddle_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" import paddle diff --git a/paddlenlp/datasets/__init__.py b/paddlenlp/datasets/__init__.py index fda1d65868cf..49fb25fcf319 100644 --- a/paddlenlp/datasets/__init__.py +++ b/paddlenlp/datasets/__init__.py @@ -25,6 +25,7 @@ from .drcd import * from .drcd_cn import * from .dureader_robust import * +from .embedding_dataset import * from .glue import * from .imdb import * from .lcqmc import * diff --git a/paddlenlp/datasets/embedding_dataset.py b/paddlenlp/datasets/embedding_dataset.py new file mode 100644 index 000000000000..da34b9164e48 --- /dev/null +++ b/paddlenlp/datasets/embedding_dataset.py @@ -0,0 +1,252 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Embedding dataset.""" + +import random +from dataclasses import dataclass +from typing import List + +from paddle.io import Dataset, IterableDataset + +from ..utils.log import logger + + +@dataclass +class Example: + """Dataset example.""" + + query: str + pos_passage: List[str] + neg_passage: List[str] = None + + +@dataclass +class Sequence: + """Sequence.""" + + token_ids: List[int] + position_ids: List[int] + + +@dataclass +class Pair: + """Pair.""" + + query: Sequence + passages: List[Sequence] + + +class EmbeddingDatasetMixin: + """EmbeddingDatasetMixin.""" + + def convert_example(tokenizer, example): + """Convert raw json format example to Example.""" + + assert all( + (key in example for key in ["query", "pos_passage", "neg_passage"]) + ), "query, pos_passage, neg_passage are needed" + + if not isinstance(example["query"], str): + raise ValueError("query must be a string.") + if isinstance(example["pos_passage"], str): + example["pos_passage"] = [example["pos_passage"]] + if isinstance(example["neg_passage"], str): + example["neg_passage"] = [example["neg_passage"]] + + if len(example["neg_passage"]) > 0: + for item in [example["query"]] + example["pos_passage"] + example["neg_passage"]: + if not isinstance(item, str): + raise ValueError("The item in pos_passage / neg_passage must be a string.") + if len(item.strip()) == 0: + raise ValueError("Example with empty string in query / pos_passage / neg_passage field.") + + query = example["query"] + pos_passage = example["pos_passage"] + neg_passage = example["neg_passage"] + return Example(query=query, pos_passage=pos_passage, neg_passage=neg_passage) + + def tokenize_template(cls, tokenizer, template: str): + """Tokenize a given template using the provided tokenizer.""" + assert template.count("{text}") == 1, "Template must contain exactly one {text} placeholder" + + template_prefix, template_suffix = template.split("{text}") + + prefix_tokens = tokenizer(template_prefix, add_special_tokens=False).input_ids + suffix_tokens = tokenizer(template_suffix, add_special_tokens=False).input_ids + return prefix_tokens, suffix_tokens + + def _process_truncation(self, tokens, text_type): + """ + Process tokens by converting them into a complete token sequence with prefix and suffix, + and generate corresponding position ids. + """ + if text_type not in ["query", "passage"]: + raise ValueError("text_type must be either 'query' or 'passage'") + + prefix_key = f"{text_type}_template_prefix" + suffix_key = f"{text_type}_template_suffix" + max_len_key = f"max_{text_type}_len" + + # If the template does not contain a suffix token, add the EOS token to the end + if getattr(self, suffix_key) == []: + setattr(self, suffix_key, [self.tokenizer.eos_token_id]) + + # Calculate the available length + max_len = getattr(self, max_len_key) + prefix_tokens = getattr(self, prefix_key) + suffix_tokens = getattr(self, suffix_key) + available_len = int(max_len - len(prefix_tokens) - len(suffix_tokens)) + + # Convert tokens to ids and truncate + token_ids_converted = self.tokenizer.convert_tokens_to_ids(tokens) + truncated_token_ids = token_ids_converted[:available_len] + + # Combine prefix, truncated tokens, and suffix + token_ids = prefix_tokens + truncated_token_ids + suffix_tokens + pos_ids = list(range(len(token_ids))) + return token_ids, pos_ids + + def _postprocess_sequence(self, example: Example): + """Post process sequence: tokenization & truncation.""" + query = example.query + pos_passage = random.choice(example.pos_passage) + neg_passage = example.neg_passage + if len(neg_passage) > 0: + if len(neg_passage) < self.group_size - 1: + # Calculate how many full sets are needed to ensure each element appears at least once + full_sets_needed = (self.group_size - 1) // len(neg_passage) + remainder = (self.group_size - 1) % len(neg_passage) + + # Initialize the list and add complete sets + selected_neg_passage = neg_passage * full_sets_needed + + # Ensure the remainder part is filled; randomly select from neg_passage + selected_neg_passage += random.sample(neg_passage, remainder) + + # Shuffle the result to ensure randomness + random.shuffle(selected_neg_passage) + else: + selected_neg_passage = random.sample(neg_passage, self.group_size - 1) + else: + selected_neg_passage = [] + # Process query tokens + query_tokens = self.tokenizer.tokenize(query) + query_token_ids, query_pos_ids = self._process_truncation(query_tokens, "query") + + query = Sequence( + token_ids=query_token_ids, + position_ids=query_pos_ids, + ) + + # Process passage tokens + passages = [] + for passage in [pos_passage] + selected_neg_passage: + passage_tokens = self.tokenizer.tokenize(passage) + passage_token_ids, passage_pos_ids = self._process_truncation(passage_tokens, "passage") + passages.append( + Sequence( + token_ids=passage_token_ids, + position_ids=passage_pos_ids, + ) + ) + return Pair(query=query, passages=passages) + + +class EmbeddingDataset(EmbeddingDatasetMixin, Dataset): + def __init__( + self, + dataset, + tokenizer, + max_query_len: int = 64, + max_passage_len: int = 256, + group_size: int = 2, + query_template: str = "{text}", + passage_template: str = "{text}", + ): + super().__init__() + self.example_dataset = dataset + self.tokenizer = tokenizer + self.max_query_len = max_query_len + self.max_passage_len = max_passage_len + self.group_size = group_size + self.query_template = query_template + self.passage_template = passage_template + self.query_template_prefix, self.query_template_suffix = self.tokenize_template( + self.tokenizer, self.query_template + ) + self.passage_template_prefix, self.passage_template_suffix = self.tokenize_template( + self.tokenizer, self.passage_template + ) + + for index, data in enumerate(self.example_dataset): + self.example_dataset[index] = self.convert_example(data) + + def __getitem__(self, index): + return self._postprocess_sequence(self.example_dataset[index]) + + def __len__(self): + raise len(self.example_dataset) + + +class EmbeddingIterableDataset(EmbeddingDatasetMixin, IterableDataset): + """Create sequences from Example Dataset. + + This is a stateful dataset. + """ + + def __init__( + self, + dataset, + tokenizer, + max_query_len: int = 64, + max_passage_len: int = 256, + group_size: int = 2, + query_template: str = "{text}", + passage_template: str = "{text}", + ): + super().__init__() + self.example_dataset = dataset + self.tokenizer = tokenizer + self.max_query_len = max_query_len + self.max_passage_len = max_passage_len + self.group_size = group_size + self.query_template = query_template + self.passage_template = passage_template + self.query_template_prefix, self.query_template_suffix = self.tokenize_template( + self.tokenizer, self.query_template + ) + self.passage_template_prefix, self.passage_template_suffix = self.tokenize_template( + self.tokenizer, self.passage_template + ) + + self.epoch_index = 0 + + def __iter__(self): + while True: + logger.info(f"Start to load dataset on epoch={self.epoch_index}") + yield from self.iter_one_epoch() + + def iter_one_epoch(self): + """Iterates through one epoch of the dataset.""" + + num_sequences = 0 + for index, example in enumerate(self.example_dataset): + example = self.convert_example(example) + sequence = self._postprocess_sequence(example) + if sequence is None: + continue + num_sequences += 1 + yield [sequence] + + self.epoch_index += 1 diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 35b3ea91f2b5..195f40e02188 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -23,15 +23,17 @@ import math import warnings from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import paddle +import paddle.distributed as dist import paddle.distributed.fleet.meta_parallel as mpu import paddle.nn.functional as F from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss from paddlenlp.transformers.refined_recompute import ( RRColumnParallelLinear, RRColumnSequenceParallelLinear, @@ -45,6 +47,7 @@ from .. import linear_utils from ..activations import ACT2FN from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..embedding_utils import dist_gather_tensor_with_gradient from ..linear_utils import Linear from ..llama import fusion_ops from ..model_outputs import ( @@ -84,6 +87,7 @@ "Qwen2PretrainingCriterion", "Qwen2ForSequenceClassification", "Qwen2ForTokenClassification", + "Qwen2SentenceEmbedding", ] @@ -1662,3 +1666,80 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class Qwen2SentenceEmbedding(Qwen2PretrainedModel): + def __init__( + self, + config: Qwen2Config, + embedding_temperature: float = 0.02, + ): + """Qwen2SentenceEmbedding + For getting larger batch_size, we use tensor parallel to get larger batch_size. + + Args: + config (Qwen2Config): _description_ + model (Qwen2Model): _description_ + embedding_temperature (float, optional): _description_. Defaults to 0.02. + """ + super(Qwen2SentenceEmbedding, self).__init__(config) + self.config = config + self.qwen2 = Qwen2Model(config) + self.in_batch_negative_loss = SimpleContrastiveLoss(embedding_temperature) + self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() + self.embedding_negatives_cross_device = config.embedding_negatives_cross_device + if self.world_size <= 1: + self.embedding_negatives_cross_device = False + + def forward( + self, + query: Optional[Dict[str, paddle.Tensor]] = None, + passages: Optional[Dict[str, paddle.Tensor]] = None, + return_encode=False, + ): + """forward""" + q_reps = self.encode(**query) + p_reps = self.encode(**passages) + + q_reps = nn.functional.normalize(q_reps, axis=-1) + p_reps = nn.functional.normalize(p_reps, axis=-1) + + if return_encode: + return q_reps, p_reps + + if self.embedding_negatives_cross_device: + q_reps = dist_gather_tensor_with_gradient(q_reps) + p_reps = dist_gather_tensor_with_gradient(p_reps) + + loss = self.in_batch_negative_loss(q_reps, p_reps) + return loss + + def encode( + self, + input_ids, + position_ids=None, + embedding_indices=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + **kwargs, + ): + """encode""" + input_type = type(input_ids) + outputs = self.qwen2( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + if isinstance(outputs, input_type): + hidden_states = outputs + else: + hidden_states = outputs[0] + last_hidden_states = hidden_states.gather_nd(embedding_indices) + return last_hidden_states