Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flashmask rm #9154

Merged
merged 8 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions llm/alignment/rm/flashmask/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# FlashMask Reward model training


## 3.1 RM 模型训练

### 数据准备

我们支持的数据格式是每行包含一个字典的 json 文件,每个字典包含以下字段:

- `src` : `str, List(str)`, 用户对话内容。
- `tgt` : `str, List(str)`, 系统回复内容。
- `response` : `str, List(str)`, 包含 chosen 和 rejected 回复。
- `sort` : `List(int)`, sort 值用于区分 response 中 chosen 和 rejected(sort 值小的是 rejected,sort 值大的是 chosen)。。

样例数据:

```text
{
"src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
"tgt": [],
"response": [
"Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
"As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
],

"sort": [1, 0]
}
...
```

为了方便测试,我们也提供了广告生成数据集可以直接使用:

```bash
wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
tar -zxvf ultrafeedback_binarized.tar.gz
```

### FlashMask RM

```bash
# RM 启动命令参考
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/rm/flashmask/run_reward.py ./config/llama/rm_flashmask_argument.json
165 changes: 165 additions & 0 deletions llm/alignment/rm/flashmask/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def check_preference_data(data):

if isinstance(data["src"], str):
data["src"] = [data["src"]]
if isinstance(data["tgt"], str):
data["tgt"] = [data["tgt"]]
if len(data["src"]) != len(data["tgt"]) + 1:
raise ValueError(
"The number of src and tgt should differ by 1, but got {} and {}".format(
len(data["src"]), len(data["tgt"])
)
)
if (len(data["response"]) != 2) or (len(data["response"]) != len(data["sort"])):
raise ValueError(
"The number of response and sort should be 2, but got {} and {}".format(
len(data["response"]), len(data["sort"])
)
)
if len(data["response"][0]) == 0 or len(data["response"][1]) == 0:
raise ValueError("The response should not be empty, buut got {data}.")
if data["sort"][0] == data["sort"][1]:
raise ValueError("The two sort should be different.")

return data


def preprocess_preference_data(data, tokenizer, data_args, model_args):
"""Convert raw format example to Example."""
# 1. Check data format
data = check_preference_data(data)

if data["sort"][0] > data["sort"][1]:
chosen = data["response"][0]
rejected = data["response"][1]
else:
chosen = data["response"][1]
rejected = data["response"][0]

chosen_token_ids = tokenizer(chosen)["input_ids"] + [tokenizer.eos_token_id]
rejected_token_ids = tokenizer(rejected)["input_ids"] + [tokenizer.eos_token_id]
prompt_tokens_ids = tokenizer(data["src"][-1], add_special_tokens=True)["input_ids"]

for idx in range(len(data["tgt"])):
src_token_ids = tokenizer(data["src"][-idx - 1], add_special_tokens=True)["input_ids"]
tgt_token_ids = tokenizer(data["tgt"][-idx])["input_ids"] + [tokenizer.eos_token_id]
prompt_tokens_ids = src_token_ids + tgt_token_ids + prompt_tokens_ids

if len(prompt_tokens_ids) + len(rejected_token_ids) + len(chosen_token_ids) > data_args.max_seq_len:
prompt_tokens_ids = prompt_tokens_ids[-data_args.max_prompt_len :]
if len(prompt_tokens_ids) + len(rejected_token_ids) + len(chosen_token_ids) > data_args.max_seq_len:
max_response_len = data_args.max_seq_len - len(prompt_tokens_ids)
# 按比例截断
max_chosen_len = int(
len(chosen_token_ids) / (len(chosen_token_ids) + len(rejected_token_ids)) * max_response_len
)
max_rejected_len = max_response_len - max_chosen_len
chosen_token_ids = chosen_token_ids[:max_chosen_len]
rejected_token_ids = rejected_token_ids[:max_rejected_len]
input_ids = prompt_tokens_ids + chosen_token_ids + rejected_token_ids
prompt_len, chosen_len, rejected_len, seq_len = (
len(prompt_tokens_ids),
len(chosen_token_ids),
len(rejected_token_ids),
len(input_ids),
)
position_ids = (
list(range(prompt_len)) # prompt
+ list(range(prompt_len, prompt_len + chosen_len)) # chosen
+ list(range(prompt_len, prompt_len + rejected_len)) # rejected
)
# response index
response_indexs = [prompt_len + chosen_len - 1, seq_len - 1]
output_dict = {
"input_ids": input_ids,
"position_ids": position_ids,
"response_indexs": response_indexs,
}

# attention mask
if model_args.flash_mask:
output_dict["attn_mask_startend_row_indices"] = (
[seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
)
else:
attention_mask = np.tri(seq_len, seq_len, dtype=bool)
attention_mask[(prompt_len + chosen_len) :, prompt_len : (prompt_len + chosen_len)] = False
output_dict["attention_mask"] = attention_mask
return output_dict


def preference_collate_fn(batch, max_seq_len=None):
"""Convert batch data into tensor."""
if max_seq_len is None:
raise ValueError("max_seq_len is None.")

input_dict = {
"input_ids": [],
"position_ids": [],
"response_indexs": [],
}
sequence = batch[0]
if "attn_mask_startend_row_indices" in sequence:
input_dict["attn_mask_startend_row_indices"] = []
use_attn_mask_startend_row_indices = True
elif "attention_mask" in sequence:
input_dict["attention_mask"] = []
use_attn_mask_startend_row_indices = False
else:
raise ValueError("attention_mask and attn_mask_startend_row_indices are both None.")

for i, sequence in enumerate(batch):
difference = max_seq_len - len(sequence["input_ids"])

input_dict["input_ids"].append(sequence["input_ids"] + [0] * difference)
input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference)
if use_attn_mask_startend_row_indices:
input_dict["attn_mask_startend_row_indices"].append(
[
sequence["attn_mask_startend_row_indices"]
+ [sequence["attn_mask_startend_row_indices"][-1]] * difference
]
)
else:
input_dict["attention_mask"].append(
np.pad(
sequence["attention_mask"],
pad_width=((0, 0), (0, difference), (0, difference)),
mode="constant",
constant_values=False,
)
)

for ri in sequence["response_indexs"]:
input_dict["response_indexs"].append(
[
i, # bs
ri[0], # chosen_response_start_index
ri[1], # rejeted_response_start_index
]
)
for key in input_dict:
if key == "attention_mask":
input_dict[key] = np.array(input_dict[key], dtype=bool)
elif key == "attn_mask_startend_row_indices":
input_dict[key] = np.array(input_dict[key], dtype=np.int32)
else:
input_dict[key] = np.array(input_dict[key])
return input_dict
101 changes: 101 additions & 0 deletions llm/alignment/rm/flashmask/reward_argument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import TrainingArguments


def add_start_docstrings(*docstr):
"""Adds docstrings for a function."""

def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn

return docstring_decorator


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
"""TrainingArguments"""

unified_checkpoint: bool = field(
default=True,
metadata={"help": "Enable fused linear grad add strategy."},
)

unified_checkpoint_config: Optional[str] = field(
default="",
metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
)


@dataclass
class DataArgument:
"""DataArgument"""

train_dataset_path: str = field(default="./data/train.jsonl", metadata={"help": "Path to the train dataset dir."})
dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
autotuner_benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)
benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)
greedy_zero_padding: bool = field(
default=False,
metadata={"help": "Whether to use Greedy Zero Padding data stream."},
)
lazy: bool = field(
default=False,
metadata={
"help": "Weather to return `MapDataset` or an `IterDataset`.True for `IterDataset`. False for `MapDataset`."
},
)


@dataclass
class ModelArgument:
"""ModelArgument"""

model_name_or_path: str = field(
default=None, metadata={"help": "Pretrained model name or path to local directory."}
)
tokenizer_name_or_path: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
recompute_granularity: str = field(
default="full",
metadata={
"help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
},
)
flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
virtual_pp_degree: int = field(
default=1,
metadata={"help": "virtual_pp_degree"},
)
sequence_parallel: bool = field(
default=False,
metadata={"help": "whether to use sequence parallel"},
)
Loading
Loading