Skip to content

Commit

Permalink
update readme.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Aug 23, 2023
1 parent 12fed8f commit 9814ddc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 50 deletions.
60 changes: 12 additions & 48 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@

## 😊 Feature

- [ChatGLM](textgen/chatglm):本项目基于PyTorch实现了ChatGLM-6B模型LoRA微调训练和预测,可以用于句子纠错、对话等文本生成任务
- [LLaMA](textgen/llama):本项目基于PyTorch实现了LLaMA模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练
- [BLOOM](textgen/bloom):本项目基于PyTorch实现了BLOOM模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练
- [GPT](textgen/gpt):本项目基于PyTorch实现了ChatGLM-6B/Baichuan/LLaMA2/BLOOM等GPT模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练
- [UDA/EDA](textgen/augment/word_level_augment.py):本项目实现了UDA(非核心词替换)、EDA和Back Translation(回译)算法,基于TF-IDF将句子中部分不重要词替换为同义词,随机词插入、删除、替换等方法,产生新的文本,实现了文本扩增
- [Seq2Seq](textgen/seq2seq):本项目基于PyTorch实现了Seq2Seq、ConvSeq2Seq、BART模型的训练和预测,可以用于文本翻译、对话生成、摘要生成等文本生成任务
- [T5](textgen/t5):本项目基于PyTorch实现了T5和CopyT5模型训练和预测,可以用于文本翻译、对话生成、对联生成、文案撰写等文本生成任务
Expand Down Expand Up @@ -126,22 +124,20 @@ python setup.py install
example: [examples/chatglm/inference_demo.py](https://github.com/shibing624/textgen/blob/main/examples/chatglm/inference_demo.py)

```python
from textgen import ChatGlmModel
from textgen import GptModel

model = ChatGlmModel("chatglm", "THUDM/chatglm-6b", peft_name="shibing624/chatglm-6b-csc-zh-lora")
r = model.predict(["对下面中文拼写纠错:\n少先队员因该为老人让坐。\n答:"])
print(r) # ['少先队员应该为老人让座。\n错误字:因,坐']
model = GptModel("chatglm", "THUDM/chatglm-6b", peft_name="shibing624/chatglm-6b-csc-zh-lora")
r = model.predict(["介绍下北京"])
print(r) # ['北京是中国的首都...']
```

PS:由于使用了开发中的peft库,可能由于版本更新,导致LoRA模型加载失败,建议使用下面的训练方法,自己训练LoRA模型。

#### 训练 ChatGLM-6B 微调模型

1. 支持自定义训练数据集和训练参数,数据集格式参考[examples/data/zh_csc_test.tsv](https://github.com/shibing624/textgen/blob/main/examples/data/zh_csc_test.tsv)或者[examples/data/json_files/belle_10.json](https://github.com/shibing624/textgen/blob/main/examples/data/json_files/belle_10.json)
2. 支持AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
1. 支持自定义训练数据集和训练参数,数据集格式参考[examples/data/sharegpt_zh_100_format.jsonl](https://github.com/shibing624/textgen/blob/main/examples/data/sharegpt_zh_100_format.jsonl)
2. 支持QLoRA、AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
3. 支持多卡训练,支持混合精度训练

example: [examples/chatglm/training_chatglm_demo.py](https://github.com/shibing624/textgen/blob/main/examples/chatglm/training_chatglm_demo.py)
example: [examples/gpt/training_chatglm_demo.py](https://github.com/shibing624/textgen/blob/main/examples/gpt/training_chatglm_demo.py)

单卡训练:
```shell
Expand All @@ -155,24 +151,6 @@ cd examples/chatglm
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 training_chatglm_demo.py --do_train --do_predict --num_epochs 20
```


#### 基于微调(LoRA)模型继续训练
如果需要基于Lora模型继续训练,可以使用下面的脚本合并模型为新的base model,再微调训练即可。

执行以下命令:
```shell
python -m textgen/chatglm/merge_peft_adapter.py \
--base_model_name_or_path path_to_original_base_model_dir \
--peft_model_path path_to_peft_model_dir \
--output_dir path_to_output_dir
```
参数说明:
```
--base_model_name_or_path:存放HF格式的底座模型权重和配置文件的目录
--peft_model_path:存放PEFT格式的微调模型权重和配置文件的目录
--output_dir:指定保存全量模型权重的目录,默认为./merged
```

### LLaMA 模型

#### 使用 LLaMA 微调后的模型
Expand All @@ -188,32 +166,24 @@ import sys
sys.path.append('../..')
from textgen import GptModel


def generate_prompt(instruction):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:{instruction}\n\n### Response:"""


model = GptModel("llama", "decapoda-research/llama-7b-hf", peft_name="ziqingyang/chinese-alpaca-lora-7b")
predict_sentence = generate_prompt("问:用一句话描述地球为什么是独一无二的。\n答:")
r = model.predict([predict_sentence])
r = model.predict(["用一句话描述地球为什么是独一无二的。"])
print(r) # ['地球是唯一一颗拥有生命的行星。']
```

</details>

#### 训练 LLaMA 微调模型
1. 支持自定义训练数据集和训练参数,数据集格式参考[examples/data/zh_csc_test.tsv](https://github.com/shibing624/textgen/blob/main/examples/data/zh_csc_test.tsv)或者[shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
2. 支持AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
1. 支持自定义训练数据集和训练参数,数据集格式参考[examples/data/sharegpt_zh_100_format.jsonl](https://github.com/shibing624/textgen/blob/main/examples/data/sharegpt_zh_100_format.jsonl)
2. 支持QLoRA、AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
3. 支持多卡训练,支持混合精度训练,使用方法同上(ChatGLM多卡训练)

example: [examples/llama/training_llama_demo.py](https://github.com/shibing624/textgen/blob/main/examples/llama/training_llama_demo.py)
example: [examples/gpt/training_llama_demo.py](https://github.com/shibing624/textgen/blob/main/examples/gpt/training_llama_demo.py)


#### 基于微调(LoRA)模型继续训练
如果需要基于Lora模型继续训练,可以使用下面的脚本合并模型为新的base model,再微调训练即可。

单LoRA权重合并(适用于 Chinese-LLaMA, Chinese-LLaMA-Plus, Chinese-Alpaca)

执行以下命令:
```shell
python -m textgen/gpt/merge_peft_adapter.py \
Expand All @@ -234,12 +204,6 @@ python -m textgen/gpt/merge_peft_adapter.py \
#### 训练领域模型
Note: 为了全面的介绍训练医疗大模型的过程,把4阶段训练方法(Pretraining, Supervised Finetuning, Reward Modeling and Reinforcement Learning)单独新建了一个repo:[shibing624/MedicalGPT](https://github.com/shibing624/MedicalGPT),请移步该repo查看训练方法。

### BLOOM 模型

#### 训练 BLOOM 微调模型

example: [examples/bloom/training_bloom_demo.py](https://github.com/shibing624/textgen/blob/main/examples/bloom/training_bloom_demo.py)

### ConvSeq2Seq 模型

训练并预测ConvSeq2Seq模型:
Expand Down
4 changes: 2 additions & 2 deletions textgen/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from transformers import (
AutoConfig,
LlamaForCausalLM,
LlamaTokenizerFast,
LlamaTokenizer,
BloomTokenizerFast,
BloomForCausalLM,
AutoModelForCausalLM,
Expand All @@ -48,7 +48,7 @@
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

MODEL_CLASSES = {
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizerFast),
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
Expand Down

0 comments on commit 9814ddc

Please sign in to comment.