
Commit

feat: support new DPO data format and update SFT config to use override API (#405)

Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: arendu <adithya.r@gmail.com>
Signed-off-by: NeMo-Aligner CI <nemo-aligner-ci@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Dec 4, 2024
1 parent 70e4f31 commit 5d4b2a7
Showing 6 changed files with 224 additions and 85 deletions.
3 changes: 3 additions & 0 deletions examples/nlp/gpt/conf/gpt_sft.yaml
@@ -87,6 +87,9 @@ model:
attention_dropout: 0.0
ffn_dropout: 0.0

global_batch_size: ${.data.train_ds.global_batch_size}
micro_batch_size: ${.data.train_ds.micro_batch_size}

steerlm2:
forward_micro_batch_size: 1 # the micro batch size for the forward pass, used to compute the weights
micro_batch_size: 1 # the steerlm2 training micro batch size
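The two new keys mirror the training dataset's batch sizes through OmegaConf's relative interpolation. As a quick illustration (a minimal sketch, not part of the commit; the numeric values below are invented), this is how such references resolve:

from omegaconf import OmegaConf

# Sketch of how the new interpolations resolve; numbers are illustrative only.
cfg = OmegaConf.create(
    {
        "model": {
            "global_batch_size": "${.data.train_ds.global_batch_size}",
            "micro_batch_size": "${.data.train_ds.micro_batch_size}",
            "data": {"train_ds": {"global_batch_size": 128, "micro_batch_size": 1}},
        }
    }
)
print(cfg.model.global_batch_size)  # 128, mirrored from data.train_ds.global_batch_size
print(cfg.model.micro_batch_size)   # 1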
83 changes: 6 additions & 77 deletions examples/nlp/gpt/train_gpt_sft.py
@@ -39,7 +39,7 @@
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo
from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo

"""Script to start SFT training"""

@@ -49,75 +49,10 @@
mp.set_start_method("spawn", force=True)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(gpt_cfg, True)
OmegaConf.resolve(cfg)
with open_dict(gpt_cfg):
gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size
gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size
gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
"activations_checkpoint_layers_per_pipeline", None
)
gpt_cfg.peft = cfg.model.peft
gpt_cfg.data = cfg.model.data
gpt_cfg.optim = cfg.model.optim
gpt_cfg.precision = cfg.trainer.precision
gpt_cfg.answer_only_loss = cfg.model.answer_only_loss
gpt_cfg.restore_from_path = cfg.model.restore_from_path
gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)
# if TP/PP size is -1, use default TP/PP size as original model
if cfg.model.get("tensor_model_parallel_size", 1) > 0:
gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1)
if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
gpt_cfg.pipeline_model_parallel_size = cfg.model.get("pipeline_model_parallel_size", 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get("pipeline_model_parallel_split_rank", 0)

if cfg.model.data.get("chat", False):
# chat model, overwrite the prompt template
prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template

sft_cls = GPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

if cfg.model.get("use_flash_attention", None) is not None:
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

if cfg.model.get("dist_ckpt_load_strictness", None) is not None:
gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(gpt_cfg)
gpt_cfg.cfg = gpt_cfg

return gpt_cfg


@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:
cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

@@ -129,17 +64,11 @@ def main(cfg) -> None:
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
cfg,
trainer,
strict=True,
modify_config_fn=_modify_config,
restore_path=cfg.model.restore_from_path,
return_updated_cfg=True,
ptl_model = load_from_nemo(
GPTSFTModel, cfg, trainer, strict=True, restore_path=cfg.model.restore_from_path, return_updated_cfg=False,
)

init_peft(ptl_model, updated_cfg)
init_peft(ptl_model, cfg.model)

with open_dict(cfg):
# overwrite the model config with the config from the checkpoint
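Judging by the call site above, `load_and_override_model_config` (from `nemo_aligner.utils.utils`) replaces the hand-rolled `_modify_config` mapping: it loads the model config stored with the `.nemo` checkpoint and layers the user's `model` overrides on top, so only the merged result is passed to `load_from_nemo` and `init_peft`. A rough, self-contained sketch of that merge idea follows; it is illustrative only, not the actual implementation, and the example values are invented:

from omegaconf import OmegaConf

def override_model_config_sketch(checkpoint_cfg, user_model_cfg):
    # Illustration of the override idea only: user-supplied keys win over the
    # defaults stored with the checkpoint. Not the real NeMo-Aligner helper.
    return OmegaConf.merge(checkpoint_cfg, user_model_cfg)

# Hypothetical configs purely for demonstration:
checkpoint_cfg = OmegaConf.create({"hidden_dropout": 0.1, "micro_batch_size": 4})
user_model_cfg = OmegaConf.create({"hidden_dropout": 0.0})
print(override_model_config_sketch(checkpoint_cfg, user_model_cfg))
# {'hidden_dropout': 0.0, 'micro_batch_size': 4}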
103 changes: 95 additions & 8 deletions nemo_aligner/data/nlp/datasets.py
@@ -15,13 +15,19 @@
"""Custom datasets for RLHF training"""

import os
from typing import Dict, List

import numpy as np
import scipy
import torch
from omegaconf import OmegaConf

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import (
GPTSFTChatDataset,
_get_header_conversation_type_mask_role,
get_prompt_template_example,
)
from nemo.core import Dataset
from nemo.utils import logging

@@ -344,16 +350,97 @@ def encode(self, text, append_eod=False):

return text_ids, len(text_ids)

@staticmethod
def _convert_messages(
input_list: List[Dict[str, str]]
) -> Dict: # TODO: (@adithyare) this method should live elsewhere.
"""
args:
input_list: a list of dicts in the OpenAI format,
for example:
[{"role": "system", "content": "you are helpful"},
{"role": "user", "content": "Why is the sky blue?"},
{"role": "assistant", "content": "Because blablabla"},
...]
returns:
output_dict: a dict in NeMo's format {"system": "system prompt",
"conversations": [],
...
}
"""
output_dict = {
"system": "",
"conversations": [],
"mask": "User",
"type": "VALUE_TO_TEXT",
}

# Extract the system message
num_system_msg = 0
for msg in input_list:
if msg["role"] == "system":
output_dict["system"] = msg["content"]
num_system_msg += 1
if num_system_msg > 1:
raise RuntimeError("Multiple system messages seen, please consolidate into a single system message.")

# Build the conversations list
for msg in input_list:
if msg["role"] != "system":
conversation_entry = {
"from": msg["role"].capitalize(), # Capitalize 'user' and 'assistant'
"value": msg["content"],
"label": None,
}
output_dict["conversations"].append(conversation_entry)

return output_dict

def convert(self, messages):
"""
args:
messages: a list of dicts in the OpenAI format,
for example:
[{"role": "system", "content": "you are helpful"},
{"role": "user", "content": "Why is the sky blue?"},
{"role": "assistant", "content": "Because blablabla"},
...]
returns:
conversation: a string formatted with the chat template
"""
if OmegaConf.select(self.cfg, "data.chat_prompt_tokens") is None:
raise RuntimeError(
"You don't have a model (model_config.yaml) which has chat_prompt_tokens, are you sure this is a Chat/Instruction model?"
)
special_tokens = self.cfg.data.chat_prompt_tokens
nemo_source = self._convert_messages(messages)
header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(
nemo_source, special_tokens
)
return conversation

def __getitem__(self, idx):
"""Returns a pair of chosen/rejected pairs, their respective lengths, and labels."""
payload = self.data[idx]
prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
chosen, chosen_len = self.encode(
payload["prompt"] + payload["chosen_response"], append_eod=self.cfg.data.get("append_eod", False)
)
reject, reject_len = self.encode(
payload["prompt"] + payload["rejected_response"], append_eod=self.cfg.data.get("append_eod", False)
)

if isinstance(payload["prompt"], str):
# (@adithyare) format with hardcoded chat tokens
# will allow this for the time being.
prompt_fmtd = payload["prompt"]
chosen_fmtd = payload["prompt"] + payload["chosen_response"]
rejected_fmtd = payload["prompt"] + payload["rejected_response"]
logging.warning(
"Pre-formatting chat conversation as string with hardcoded chat tokens will be deprecated."
) # (@adithyare) this will spam the console for now.
else:
prompt_fmtd = self.convert(payload["prompt"]) # (@adithyare) read var as "prompt formatted"
chosen_fmtd = self.convert(payload["prompt"] + [payload["chosen_response"]])
rejected_fmtd = self.convert(payload["prompt"] + [payload["rejected_response"]])

prompt, prompt_len = self.encode(prompt_fmtd, append_eod=False)
chosen, chosen_len = self.encode(chosen_fmtd, append_eod=self.cfg.data.get("append_eod", False))
reject, reject_len = self.encode(rejected_fmtd, append_eod=self.cfg.data.get("append_eod", False))

# chosen_response_only, chosen_response_len = self.encode(payload['chosen_response'])
# reject_response_only, reject_response_len = self.encode(payload['rejected_response'])
chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:]
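For reference, here is a small standalone sketch of the OpenAI-messages-to-NeMo conversion that `_convert_messages` performs. The messages are made up and the snippet mirrors the docstring above rather than importing the dataset class; only the structure matters:

# Standalone illustration of the _convert_messages mapping; made-up content.
messages = [
    {"role": "system", "content": "you are helpful"},
    {"role": "user", "content": "Why is the sky blue?"},
    {"role": "assistant", "content": "Because of Rayleigh scattering."},
]

nemo_source = {"system": "", "conversations": [], "mask": "User", "type": "VALUE_TO_TEXT"}
for msg in messages:
    if msg["role"] == "system":
        nemo_source["system"] = msg["content"]
    else:
        nemo_source["conversations"].append(
            {"from": msg["role"].capitalize(), "value": msg["content"], "label": None}
        )

# nemo_source now holds:
# {"system": "you are helpful",
#  "conversations": [{"from": "User", "value": "Why is the sky blue?", "label": None},
#                    {"from": "Assistant", "value": "Because of Rayleigh scattering.", "label": None}],
#  "mask": "User", "type": "VALUE_TO_TEXT"}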
68 changes: 68 additions & 0 deletions nemo_aligner/data/nlp/scripts/undo_special_tokens.py
@@ -0,0 +1,68 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to remove special tokens from dpo datasets
and convert them into list of messages format"""

import argparse
import json
import re


def format_conversation(input_string):
# Define roles and patterns
role_patterns = {"<extra_id_0>System": "system", "<extra_id_1>User": "user", "<extra_id_1>Assistant": "assistant"}

# Initialize an empty output list
conversation = []

# Use regex to find each segment's role and content
segments = re.findall(r"(<extra_id_[0-1]>[^\n]+)\n(.*?)((?=<extra_id_)|$)", input_string, re.DOTALL)

for segment in segments:
role_tag, content, _ = segment
role = role_patterns.get(role_tag.strip(), "unknown")
conversation.append({"role": role, "content": content.strip()})

return conversation


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process a JSONL file.")
parser.add_argument("input_jsonl", type=str, help="Path to the input JSONL file.")
# Parse the arguments
args = parser.parse_args()

input_jsonl = args.input_jsonl
output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl")

with open(input_jsonl, "r") as f, open(output_jsonl, "w") as w:
for line in f:
j = json.loads(line)
prompt = j["prompt"]
undo_spl_prompt = format_conversation(prompt)
empty_assistant = undo_spl_prompt.pop()
chosen, rejected = j["chosen_response"], j["rejected_response"]
chosen = chosen.split("\n<extra_id_1>")[0]
rejected = rejected.split("\n<extra_id_1>")[0]
chosen_message = {"role": empty_assistant["role"], "content": chosen}
rejected_message = {"role": empty_assistant["role"], "content": rejected}
j_out = {
"prompt": undo_spl_prompt,
"chosen_response": chosen_message,
"rejected_response": rejected_message,
"chosen_reward": j["chosen_reward"],
"rejected_reward": j["rejected_reward"],
}
w.write(json.dumps(j_out) + "\n")
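Running the script as `python undo_special_tokens.py <input>.jsonl` writes `<input>.no_special_toks.jsonl` next to the input. The sketch below (invented text and reward values; only the shapes follow the code above) shows how one record is transformed into the new DPO message format consumed by the dataset change in `datasets.py`:

# Illustrative before/after of a single JSONL record handled by this script.
record_in = {
    "prompt": "<extra_id_0>System\n\n<extra_id_1>User\nWhy is the sky blue?\n<extra_id_1>Assistant\n",
    "chosen_response": "Because of Rayleigh scattering.\n<extra_id_1>",
    "rejected_response": "It just is.\n<extra_id_1>",
    "chosen_reward": 1.0,
    "rejected_reward": -1.0,
}

record_out = {
    "prompt": [
        {"role": "system", "content": ""},
        {"role": "user", "content": "Why is the sky blue?"},
    ],
    "chosen_response": {"role": "assistant", "content": "Because of Rayleigh scattering."},
    "rejected_response": {"role": "assistant", "content": "It just is."},
    "chosen_reward": 1.0,
    "rejected_reward": -1.0,
}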
1 change: 1 addition & 0 deletions setup/requirements.txt
@@ -1,3 +1,4 @@
Jinja2~=3.1.4
jsonlines
megatron_core>=0.8
nemo_toolkit[nlp]
51 changes: 51 additions & 0 deletions tests/test_datasets.py
@@ -22,6 +22,7 @@
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo_aligner.algorithms.dpo import dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.data.nlp.scripts.undo_special_tokens import format_conversation
from nemo_aligner.utils import parallel_state


@@ -136,6 +137,56 @@ def test_dpo_loader(init_model_parallel, make_tmp_jsonl, llama3_tokenizer):
assert num_mini_batches == 2


@pytest.mark.run_only_on("GPU")
def test_dpo_dataset_conversion():
prompt_str = """<extra_id_0>System\n\n<extra_id_1>User\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\n<extra_id_1>Assistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\n<extra_id_1>User\nCould you explain in detail both the advantages and disadvantages from different perspectives?\n<extra_id_1>Assistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\n<extra_id_1>User\ncould you add more in a table\n<extra_id_1>Assistant\n"""

expected_oai_messages = [
{"role": "system", "content": ""},
{
"role": "user",
"content": "Does GPT-4 use RNN or Transformer models, and which one is better for this type of project?",
},
{
"role": "assistant",
"content": "GPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.",
},
{
"role": "user",
"content": "Could you explain in detail both the advantages and disadvantages from different perspectives?",
},
{
"role": "assistant",
"content": """Yes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.""",
},
{"role": "user", "content": "could you add more in a table"},
{"role": "assistant", "content": ""},
]

oai_messages_prompt = format_conversation(prompt_str)
assert expected_oai_messages == oai_messages_prompt

# (@adithyare) bonus test! convert oai style messages back into a string using Jinja
# Attempt to import jinja2 via importorskip
jinja2 = pytest.importorskip("jinja2", reason="jinja2 library is not installed")

# Now it's safe to use jinja2
from jinja2 import Template

def remove_trailing(s, t):
if s.endswith(t):
s = s[: -len(t)]
return s

jinja_template = """{% for message in conversation %}{%- if message.role == "system" -%}<extra_id_0>System\n{{ message.content }}\n{% elif message.role == "user" -%}<extra_id_1>User\n{{ message.content }}\n{% elif message.role == "assistant" -%}<extra_id_1>Assistant\n{{ message.content }}\n{% endif %}{% endfor %}"""
jinja_template = Template(jinja_template)
prompt_str_jinja_rendered = jinja_template.render(conversation=oai_messages_prompt)
prompt_str_jinja_rendered = remove_trailing(
prompt_str_jinja_rendered, "\n"
) # (@adithyare) Jinja adds the end-of-message token, which we remove to recover the prompt.
assert prompt_str == prompt_str_jinja_rendered


@pytest.mark.run_only_on("GPU")
def test_dpo_loader_original(init_model_parallel, make_tmp_jsonl, llama3_tokenizer):
init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
