Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to provide prior NeMo 2 ckpt path to convert_nemo1_to_nemo2.py #11452

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions scripts/checkpoint_converters/convert_nemo1_to_nemo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
--output_path=your_output_dir \
--model_id=meta-llama/Meta-Llama-3-8B

b. Convert a model weight directory.
b. Convert a model weight directory.
The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file.
Please also provide tokenizer_library and tokenizer_path when loading from weight directory.
python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
Expand All @@ -42,6 +42,7 @@
from pathlib import Path

import torch
from genericpath import isdir
Fixed Show fixed Hide fixed
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedObject
from omegaconf import OmegaConf
Expand Down Expand Up @@ -78,7 +79,7 @@
Parse the command line arguments.
"""
parser = ArgumentParser(
description="""Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format.
description="""Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format.
This script may download from Hugging Face, make sure you have
access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)"""
)
Expand All @@ -88,7 +89,7 @@
default=None,
required=True,
help="""Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory a
fter untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass
fter untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass
in `model_weights` directory.""",
)
parser.add_argument(
Expand Down Expand Up @@ -123,6 +124,13 @@
Returns:
llm.GPTModel: NeMo 2.0 model instance
"""
if os.path.isdir(model_id):
from nemo.lightning import io

model = io.load_context(Path(model_id), subpath="model")
model.config.bf16 = True
model.config.params_dtype = torch.bfloat16
return model

if model_id not in MODEL_CONFIG_MAPPING:
valid_ids = "\n- ".join([""] + list(MODEL_CONFIG_MAPPING.keys()))
Expand Down
130 changes: 130 additions & 0 deletions scripts/llm/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This script is just an example of using NeMo checkpoints
# for generating outputs and is subject to change without notice.

from argparse import ArgumentParser

import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api


def get_args():
    """
    Parse the command line arguments.

    Returns:
        argparse.Namespace: parsed arguments (checkpoint path, parallelism
        sizes, device topology, and sampling parameters).
    """
    arg_parser = ArgumentParser(description="""Run generation on a few sample prompts given the checkpoint path.""")

    # The only required argument: where the NeMo 2 checkpoint lives.
    arg_parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="""Path to NeMo 2 checkpoint""",
    )

    # Integer-valued parallelism/topology knobs; all default to a single-GPU run.
    for flag, default, help_text in (
        ("--tp", 1, """Tensor parallel size"""),
        ("--pp", 1, """Pipeline parallel size"""),
        ("--devices", 1, """Number of GPUs to use on a single node"""),
        ("--nodes", 1, """Number of nodes to use"""),
    ):
        arg_parser.add_argument(flag, type=int, default=default, help=help_text)

    # Sampling controls forwarded into CommonInferenceParams by the caller.
    arg_parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="""Temperature to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    arg_parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    arg_parser.add_argument(
        "--num_tokens_to_generate",
        type=int,
        default=4,
        help="""Number of tokens to generate per prompt""",
    )

    return arg_parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    # Model-parallel layout for inference only: optimizer state is neither
    # set up nor stored.
    parallel_strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp,
        pipeline_model_parallel_size=args.pp,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    # bf16 mixed precision throughout (params and pipeline comms), with
    # autocast disabled and gradient reduction kept out of fp32.
    precision_plugin = nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
        autocast_enabled=False,
        grad_reduce_in_fp32=False,
    )
    inference_trainer = nl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        num_nodes=args.nodes,
        strategy=parallel_strategy,
        plugins=precision_plugin,
    )

    # A few fixed sample prompts to sanity-check generation from the checkpoint.
    sample_prompts = [
        "Hello, how are you?",
        "How many r's are in the word 'strawberry'?",
        "Which number is bigger? 10.119 or 10.19?",
    ]
    sampling_params = CommonInferenceParams(
        temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=args.num_tokens_to_generate
    )
    generated_texts = api.generate(
        path=args.model_path,
        prompts=sample_prompts,
        trainer=inference_trainer,
        inference_params=sampling_params,
        text_only=True,
    )

    # Only rank 0 prints, so multi-rank runs emit each result exactly once.
    if torch.distributed.get_rank() == 0:
        for prompt, generated in zip(sample_prompts, generated_texts):
            print(prompt)
            print("*" * 50)
            print(generated)
            print("\n\n")
66 changes: 0 additions & 66 deletions scripts/llm/llama3_generate.py

This file was deleted.

Loading