
Commit

add steerlm
Signed-off-by: Zhilin Wang <zhilinw@nvidia.com>
Zhilin123 committed Dec 2, 2023
1 parent 94288ef commit a9339f7
Showing 22 changed files with 1,099 additions and 227 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ The toolkit is currently in its early stages, and we are committed to improving

## Key features

* **SteerLM**
* **SteerLM: Attribute Conditioned SFT as a (User-Steerable) Alternative to RLHF.** Learn more in our [SteerLM](https://arxiv.org/abs/2310.05344) and [HelpSteer](https://arxiv.org/abs/2311.09528) papers. Try it instantly for free on [NVIDIA AI Playground](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/llama2-70b-steerlm)
* **Supervised Fine Tuning**
* **Reward Model Training**
* **Reinforcement Learning from Human Feedback using the PPO Algorithm**
2 changes: 1 addition & 1 deletion docs/README.md
@@ -3,7 +3,7 @@
## Custom Trainers

NeMo-Aligner uses custom trainers to coordinate all aspects of training. There are currently 3 custom trainers:
1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT and Reward modeling.
1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT, SteerLM, and Reward modeling.
2. [CriticServerTrainer](/nemo_aligner/algorithms/critic_server_trainer.py): trains the RL critic via PyTriton requests. It will also run the reward model depending on the configuration.
3. [PPOTrainer](/nemo_aligner/algorithms/ppo.py): performs the RLHF PPO training, since PPO has components such as the Critic, this trainer will send inference and train requests via [PyTriton](https://github.com/triton-inference-server/pytriton) to the CriticServerTrainer to train and run inference on the critic.

385 changes: 184 additions & 201 deletions docs/user-guide/SteerLM.rst

Large diffs are not rendered by default.

154 changes: 154 additions & 0 deletions examples/nlp/data/steerlm/attribute_annotate.py
@@ -0,0 +1,154 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script annotates each labeled turn in a dataset with attribute values by sending requests to a regression reward model server.
"""


import argparse
import json
import os
from typing import List

import jsonlines
import numpy as np
from common import (
    ALL_STEERLM_ATTRIBUTES,
    ASSISTANT_TURN_TEMPLATE,
    LABEL_PREFIX,
    SYSTEM_PROMPT,
    SYSTEM_PROMPT_TEMPLATE,
    USER_TURN_TEMPLATE,
)
from pytriton.client import FuturesModelClient
from tqdm import tqdm, trange


def _str_list2numpy(str_list: List[str]) -> np.ndarray:
    str_ndarray = np.array(str_list)[..., np.newaxis]
    return np.char.encode(str_ndarray, "utf-8")


def prepare_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--port", type=int, default=5555)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--model_name", type=str, default="reward_model")
    parser.add_argument("--add-eos", action="store_true")
    return parser.parse_args()


def get_reward(
    sentences: List[str], add_EOS=False, host="localhost", port=5555, model_name="reward_model",
):
    sentences = _str_list2numpy(sentences)

    futures = []

    with FuturesModelClient(f"{host}:{port}", model_name) as client:
        for sen in np.split(sentences, sentences.shape[0]):
            add_EOS_arr = np.ones_like(sen, dtype=bool) * add_EOS
            future = client.infer_batch(sentences=sen, add_EOS=add_EOS_arr)
            futures.append(future)

    all_result_dicts = [f.result() for f in futures]

    all_rewards, all_exceeded = [], []

    for output_dict in all_result_dicts:
        reward_out = output_dict["rewards"].flatten().tolist()

        all_rewards.append(reward_out)
        all_exceeded += output_dict["exceeded"].tolist()

    return all_rewards, all_exceeded


def get_key(l):
    convs = [c["value"] for c in l["conversations"]]
    return "".join(convs)


def main(args):
    inference_output = args.output_file

    # Collect samples that were already annotated in a previous run so they can be skipped.
    exist = set()
    if os.path.exists(inference_output):
        with jsonlines.open(inference_output) as reader:
            for obj in tqdm(reader):
                exist.add(get_key(obj))

    fout = open(inference_output, "a", encoding="utf-8")

    # to warm up the jit
    _ = get_reward(["hello world!"], add_EOS=args.add_eos, host=args.host, port=args.port, model_name=args.model_name)

    all_samples, inputs = [], []

    # Build one prompt per labeled turn, each ending with LABEL_PREFIX so the reward model
    # scores the conversation up to and including that turn.
    with jsonlines.open(args.input_file) as reader:
        for obj in tqdm(reader):
            if get_key(obj) in exist:
                continue
            user = obj["mask"]
            turns = []
            text = SYSTEM_PROMPT_TEMPLATE.format(value=SYSTEM_PROMPT)
            for turn in obj["conversations"]:
                value = turn["value"]
                if turn["from"] == user:
                    text += USER_TURN_TEMPLATE.format(value=value)
                else:
                    text += ASSISTANT_TURN_TEMPLATE.format(value=value)
                if "label" in turn and turn["label"] is not None:
                    out_text = text + LABEL_PREFIX
                    turns.append(out_text)

            all_samples.append(turns)
            inputs.append(obj)

    print(f"exist {len(exist)}, rest {len(inputs)}")
    if len(inputs) == 0:
        exit(0)

    for idx in trange(0, len(all_samples)):
        input = inputs[idx]
        sample = all_samples[idx]
        rewards_all, _ = get_reward(
            sample, add_EOS=args.add_eos, host=args.host, port=args.port, model_name=args.model_name
        )

        # Clamp each attribute score to [0, 4], round it, and write the scores back into
        # the turn's "label" field as comma-separated "attribute:score" pairs.
        t = 0
        for turn in input["conversations"]:
            if "label" in turn and turn["label"] is not None:
                reward = rewards_all[t]
                t += 1

                reward_each = [min(4.0, max(0.0, float(r))) for r in reward]
                reward_each = [round(r) for r in reward_each]

                reward_string = ",".join(f"{a}:{r}" for a, r in zip(ALL_STEERLM_ATTRIBUTES, reward_each))
                turn["label"] = reward_string

        assert t == len(rewards_all)

        fout.write(json.dumps(input) + "\n")

    print("all annotations finished")
    fout.close()


if __name__ == "__main__":
    main(prepare_args())
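
For context, a minimal sketch (not part of the commit) of how the per-attribute scores returned by the reward model server become the label string written back into each assistant turn; the raw score values below are invented for illustration:

# Sketch only: raw scores are invented; the clamp/round/join logic mirrors main() above.
from common import ALL_STEERLM_ATTRIBUTES

raw_rewards = [3.7, 0.2, 1.4, 2.9, 3.1, 3.8, 3.6, 1.2, 2.4]  # one score per attribute
clipped = [round(min(4.0, max(0.0, float(r)))) for r in raw_rewards]  # clamp to [0, 4], then round
label = ",".join(f"{a}:{r}" for a, r in zip(ALL_STEERLM_ATTRIBUTES, clipped))
# label == "quality:4,toxicity:0,humor:1,creativity:3,helpfulness:3,correctness:4,coherence:4,complexity:1,verbosity:2"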
18 changes: 18 additions & 0 deletions examples/nlp/data/steerlm/common.py
@@ -0,0 +1,18 @@
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)

SYSTEM_PROMPT_TEMPLATE = "<extra_id_0>System\n{value}\n"

USER_TURN_TEMPLATE = "<extra_id_1>User\n{value}\n"

ASSISTANT_TURN_TEMPLATE = "<extra_id_1>Assistant\n{value}\n"

LABEL_PREFIX = "<extra_id_2>"

OPEN_ASSISTANT_ATTRIBUTES = ["quality", "toxicity", "humor", "creativity"]

HELPSTEER_ATTRIBUTES = ["helpfulness", "correctness", "coherence", "complexity", "verbosity"]

ALL_STEERLM_ATTRIBUTES = OPEN_ASSISTANT_ATTRIBUTES + HELPSTEER_ATTRIBUTES
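
Taken together, these constants define the prompt layout that attribute_annotate.py assembles turn by turn. A rough sketch of the string scored by the reward model for a single exchange (the user/assistant text is invented):

# Sketch only: shows how the templates above compose; the conversation text is invented.
from common import (
    ASSISTANT_TURN_TEMPLATE,
    LABEL_PREFIX,
    SYSTEM_PROMPT,
    SYSTEM_PROMPT_TEMPLATE,
    USER_TURN_TEMPLATE,
)

prompt = (
    SYSTEM_PROMPT_TEMPLATE.format(value=SYSTEM_PROMPT)
    + USER_TURN_TEMPLATE.format(value="How do I sort a list in Python?")
    + ASSISTANT_TURN_TEMPLATE.format(value="Use the built-in sorted() function.")
    + LABEL_PREFIX
)
# "<extra_id_0>System\n...\n<extra_id_1>User\n...\n<extra_id_1>Assistant\n...\n<extra_id_2>"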
82 changes: 82 additions & 0 deletions examples/nlp/data/steerlm/preprocess_helpsteer_data.py
@@ -0,0 +1,82 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script preprocesses the HelpSteer dataset from its HuggingFace format into the Attribute-conditioned SFT training format.
"""

import argparse
import json
import os

from common import HELPSTEER_ATTRIBUTES, SYSTEM_PROMPT
from datasets import load_dataset


def download_helpsteer():
    ds = load_dataset("nvidia/HelpSteer")
    train = ds["train"]
    val = ds["validation"]
    return train, val


def format_label(dp):
    label_list = []
    for attr in HELPSTEER_ATTRIBUTES:
        label_list.append(f"{attr}:{dp[attr]}")
    return ",".join(label_list)


def process_dataset(data):
    output = []
    for dp in data:
        conversation_obj = {}
        conversation_obj["conversations"] = [
            {"value": dp["prompt"], "from": "User", "label": None},
            {"value": dp["response"], "from": "Assistant", "label": format_label(dp)},
        ]
        conversation_obj["system"] = SYSTEM_PROMPT
        conversation_obj["mask"] = "User"
        conversation_obj["type"] = "VALUE_TO_TEXT"
        output.append(conversation_obj)
    return output


def main(output_dir):
    train, val = download_helpsteer()

    os.makedirs(output_dir, exist_ok=True)
    processed_train = process_dataset(train)
    with open(f"{output_dir}/train.jsonl", "w", encoding="utf-8") as f:
        for record in processed_train:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    processed_val = process_dataset(val)
    with open(f"{output_dir}/val.jsonl", "w", encoding="utf-8") as f:
        for record in processed_val:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-dir",
        "--output_directory",
        required=True,
        help="folder to store the created train.jsonl and val.jsonl; will be created if it does not exist",
    )
    args = parser.parse_args()

    main(args.output_directory)
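
For reference, a hypothetical example of one record written to train.jsonl by process_dataset(); the prompt, response, and attribute scores are invented, but the field names and label format follow the code above:

# Sketch only: example output record; the conversation text and scores are invented.
from common import SYSTEM_PROMPT

example_record = {
    "conversations": [
        {"value": "Explain photosynthesis briefly.", "from": "User", "label": None},
        {
            "value": "Plants use sunlight, water, and CO2 to produce glucose and oxygen.",
            "from": "Assistant",
            "label": "helpfulness:4,correctness:4,coherence:4,complexity:1,verbosity:1",
        },
    ],
    "system": SYSTEM_PROMPT,
    "mask": "User",
    "type": "VALUE_TO_TEXT",
}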

