Adithyare/mamba dpo #374

Draft
wants to merge 8 commits into base: main
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -57,6 +57,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
mamba_hybrid: False

dpo:
# This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_sft.yaml
@@ -191,7 +191,7 @@ model:
output_original_text: True # needed for the proper metrics support

optim:
name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -21,7 +21,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -53,7 +53,7 @@ def main(cfg) -> None:
logger = CustomLoggerWrapper(trainer.loggers)

ptl_model = load_from_nemo(
MegatronGPTDPOModel,
MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel,
cfg.model,
trainer,
strict=True,
16 changes: 8 additions & 8 deletions examples/nlp/gpt/train_gpt_sft.py
@@ -27,7 +27,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
@@ -39,8 +39,7 @@
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo

from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo
"""Script to start SFT training"""

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
@@ -115,6 +114,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):

@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:
cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model)
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
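The load_and_override_model_config call added above means the SFT config no longer has to restate the full architecture: the model config stored inside the .nemo checkpoint is loaded first and the user-supplied model section (including the new mamba_hybrid flag) is layered on top. A rough sketch of the assumed behaviour, reusing load_checkpoint_model_config from this repo; the helper name and body below are illustrative, not the actual implementation:

from omegaconf import OmegaConf
from nemo_aligner.utils.utils import load_checkpoint_model_config

def load_and_override_model_config_sketch(restore_from_path, user_model_cfg):
    checkpoint_cfg = load_checkpoint_model_config(restore_from_path)  # config packed inside the .nemo file
    return OmegaConf.merge(checkpoint_cfg, user_model_cfg)            # user overrides such as mamba_hybrid win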

@@ -126,17 +126,15 @@ def main(cfg) -> None:
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
ptl_model = load_from_nemo(
MambaSFTModel if cfg.model.mamba_hybrid else GPTSFTModel,
cfg,
trainer,
strict=True,
modify_config_fn=_modify_config,
restore_path=cfg.model.restore_from_path,
return_updated_cfg=True,
)

init_peft(ptl_model, updated_cfg)
init_peft(ptl_model, cfg.model)

with open_dict(cfg):
# overwrite the model config with the config from the checkpoint
@@ -170,6 +168,7 @@ def main(cfg) -> None:
train_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.mamba_hybrid,
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
@@ -182,6 +181,7 @@ def main(cfg) -> None:
val_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.mamba_hybrid,
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
18 changes: 18 additions & 0 deletions nemo_aligner/algorithms/dpo.py
@@ -29,6 +29,15 @@
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory

def pad_sequence_to_max(sequences, max_len, padding_value=0):
    # Right-pad a (batch, seq_len) tensor out to `max_len` with `padding_value`.
    if sequences.size(1) > max_len:
        raise RuntimeError("max_len must be at least the current sequence length")
    pad_size = max_len - sequences.size(1)
    padding = torch.full((sequences.size(0), pad_size), padding_value, dtype=sequences.dtype, device=sequences.device)
    padded_sequences = torch.cat([sequences, padding], dim=1)
    return padded_sequences

def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
chosen_tokens = [item["chosen"] for item in batch]
@@ -317,6 +326,15 @@ def augment_dataloader(self, dataloader):
while True:
try:
batch = next(iter_dataloader)
if self.model.cfg.mamba_hybrid:
max_seq_len = max([batch['chosen'].size(-1), batch['rejected'].size(-1), batch['chosen_labels'].size(-1), batch['rejected_labels'].size(-1)])
max_seq_len = torch.tensor(max_seq_len, device=torch.cuda.current_device())
torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX)
max_seq_len = ((max_seq_len.item() + 255) // 256) * 256
batch["chosen"] = pad_sequence_to_max(batch["chosen"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
batch["chosen_labels"] = pad_sequence_to_max(batch["chosen_labels"], max_seq_len, padding_value=-100)
batch["rejected"] = pad_sequence_to_max(batch["rejected"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
batch["rejected_labels"] = pad_sequence_to_max(batch["rejected_labels"], max_seq_len, padding_value=-100)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
batch["ref_policy_log_probs_chosen"] = chosen_logps
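The block above gives the hybrid Mamba model one uniform sequence length per global batch: each rank takes its local max over chosen/rejected tokens and labels, the all-reduce takes the max across data-parallel ranks, and the result is rounded up to the next multiple of 256, a chunk alignment the fused Mamba kernels are assumed to prefer. A worked example of the round-up:

max_seq_len = 1001                                 # illustrative global max after the all-reduce
padded_len = ((max_seq_len + 255) // 256) * 256
assert padded_len == 1024                          # next multiple of 256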
3 changes: 2 additions & 1 deletion nemo_aligner/data/nlp/builders.py
@@ -266,7 +266,7 @@ def build_dataset(index, name):
build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset)


def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
def build_sft_dataset(data_cfg, tokenizer, num_samples, is_mamba, answer_only_loss=True, is_chat=True, special_tokens=None):
packed_sequence = data_cfg.get("packed_sequence", False)
dataset_kwargs = {}

@@ -298,6 +298,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
answer_only_loss=answer_only_loss,
truncation_field=data_cfg.get("truncation_field", "text"),
pad_to_max_length=data_cfg.get("pad_to_max_length", False),
pad_seq_length_to_mult=256 if is_mamba else 16,
index_mapping_dir=data_cfg.get("index_mapping_dir", None),
prompt_template=data_cfg.get("prompt_template", None),
virtual_tokens=0,
70 changes: 70 additions & 0 deletions nemo_aligner/data/nlp/scripts/undo_special_tokens.py
@@ -0,0 +1,70 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to remove special tokens from dpo datasets
and convert them into list of messages format"""

import json
import re
import sys
input_jsonl = sys.argv[1]
output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl")

def format_conversation(input_string):
# Define roles and patterns
role_patterns = {
"<extra_id_0>System": "system",
"<extra_id_1>User": "user",
"<extra_id_1>Assistant": "assistant"
}

# Initialize an empty output list
conversation = []

# Use regex to find each segment's role and content
segments = re.findall(r"(<extra_id_[0-1]>[^\n]+)\n(.*?)((?=<extra_id_)|$)", input_string, re.DOTALL)

for segment in segments:
role_tag, content, _ = segment
role = role_patterns.get(role_tag.strip(), "unknown")
conversation.append({"role": role, "content": content.strip()})

empty_asst = conversation.pop()

return conversation, empty_asst
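A compact, illustrative example of what format_conversation returns for a prompt in the old special-token format (the old_format_example at the bottom of this script exercises the same path on a full conversation):

example = "<extra_id_0>System\nBe helpful.\n<extra_id_1>User\nHi there\n<extra_id_1>Assistant\n"
messages, empty_asst = format_conversation(example)
# messages   -> [{"role": "system", "content": "Be helpful."},
#                {"role": "user", "content": "Hi there"}]
# empty_asst -> {"role": "assistant", "content": ""}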

with open(input_jsonl, "r") as f, open(output_jsonl, "w") as w:
for line in f:
j = json.loads(line)
prompt = j["prompt"]
undo_spl_prompt, empty_assistant = format_conversation(prompt)
chosen, rejected = j["chosen_response"], j["rejected_response"]
chosen = chosen.split("\n<extra_id_1>")[0]
rejected = rejected.split("\n<extra_id_1>")[0]
chosen_message = {"role": empty_assistant["role"], "content": chosen}
rejected_message = {"role": empty_assistant["role"], "content": rejected}
j_out = {"prompt": undo_spl_prompt, "chosen_response": chosen_message, "rejected_response": rejected_message, "chosen_reward": j["chosen_reward"], "rejected_reward": j["rejected_reward"]}
w.write(json.dumps(j_out) + "\n")

old_format_example="""<extra_id_0>System\n\n<extra_id_1>User\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\n<extra_id_1>Assistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\n<extra_id_1>User\nCould you explain in detail both the advantages and disadvantages from different perspectives?\n<extra_id_1>Assistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\n<extra_id_1>User\ncould you add more in a table\n<extra_id_1>Assistant\nHere are some more pros and cons of RNN and transformer models:\n \n\n Pros of RNNs:\n \n\n - Better sequential memory and handling of sequences\n - Simpler and easier to understand\n - Faster to run\n - Good for teaching and research\n \n\n Cons of RNNs:\n \n\n - Limited parallelism\n - Poor handling of large input datasets\n - Not as accurate as transformers\n - Not as flexible as transformers\n \n\n Pros of transformers:\n \n\n - Better parallelism and handling of large input datasets\n - More flexible and versatile\n - Higher accuracy\n - Good for actual applications in industry and the real world\n \n\n Cons of transformers:\n \n\n - More complex and harder to understand and debug\n - Slower to run\n - Not as good for teaching and research\n<extra_id_1>User\nFor a chat-based music recommendation platform, which model is recommended? Please mention any model that works well, including those mentioned above or any others.\n<extra_id_1>Assistant\nFor a chat-based music recommendation platform, I would recommend using a transformer architecture, such as the GPT-4 model. Transformers are better for handling large input datasets, and they are more flexible and versatile than RNNs. They are also more accurate, so they are better for actually recommending music to users in a real-world application. RNNs are better for teaching and research, as they are simpler and easier to understand and debug.\n<extra_id_1>User\ncould you add more?\n<extra_id_1>Assistant\n"""

print("****************************************")
print("old format prompt")
print("****************************************")
print(old_format_example)
print("****************************************")
print("new format prompt")
print("****************************************")
new_format_example, _ = format_conversation(old_format_example)
print(json.dumps(new_format_example, indent=2))
6 changes: 6 additions & 0 deletions nemo_aligner/models/nlp/gpt/gpt_sft_model.py
@@ -22,6 +22,7 @@
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy
from nemo.collections.nlp.modules.common.text_generation_utils import (
@@ -225,3 +226,8 @@ def finish_inference(self):
self._restore_activation_checkpointing_args()
self._restore_sequence_parallelism_args()
set_train(self)


class MambaSFTModel(MegatronMambaModel, GPTSFTModel):
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
7 changes: 7 additions & 0 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -16,13 +16,16 @@
from functools import partial

import torch
from megatron.core import parallel_state
from megatron.core.models.mamba import MambaModel
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.utils import divide
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_iterator_k_split,
@@ -460,3 +463,7 @@ def get_ref_policy_logprobs(self, batch):

# return in GPU, trainer needs to move to cpu
return ref_log_probs

class MegatronMambaDPOModel(MegatronMambaModel, MegatronGPTDPOModel):  # @adithyare inheritance order matters
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
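The inheritance-order note matters because, for any method both parents define, Python's MRO resolves to the first listed base: MegatronMambaModel's overrides (model construction, forward path) take precedence, while anything defined only on MegatronGPTDPOModel, such as the DPO loss and log-prob machinery, is still inherited. A minimal, self-contained illustration with toy classes, not the NeMo ones:

class GPTModelToy:                        # stands in for MegatronGPTModel
    def forward(self):
        return "gpt forward"

class MambaToy(GPTModelToy):              # stands in for MegatronMambaModel
    def forward(self):
        return "mamba forward"

class GPTDPOToy(GPTModelToy):             # stands in for MegatronGPTDPOModel
    def forward(self):
        return "gpt forward"
    def get_loss_and_metrics(self):
        return "dpo loss"

class MambaDPOToy(MambaToy, GPTDPOToy):   # same base-class order as in this PR
    pass

m = MambaDPOToy()
assert m.forward() == "mamba forward"          # first base in the MRO wins for shared methods
assert m.get_loss_and_metrics() == "dpo loss"  # DPO-only methods are still inherited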
5 changes: 3 additions & 2 deletions nemo_aligner/utils/utils.py
@@ -28,7 +28,7 @@

import torch
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.num_microbatches_calculator import reconfigure_microbatch_calculator
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator as reconfigure_microbatch_calculator
from omegaconf import DictConfig, OmegaConf
from torch.masked import as_masked_tensor

@@ -122,7 +122,8 @@ def load_checkpoint_model_config(restore_path):
return OmegaConf.load(cfg_path)

with tempfile.TemporaryDirectory() as tmpdir:
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True)
members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name)
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members)
cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))

return cfg