fix: adds missing support for mcore dist opt and adds test for moe
Signed-off-by: Terry Kong <terryk@nvidia.com>
terrykong committed Nov 7, 2024
1 parent 0806f7d commit c18de8d
Showing 21 changed files with 171 additions and 45 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -93,6 +93,8 @@ jobs:
       - dpo-llama3
       - sft-llama3
       - rm-llama3
+      - dpo-mixtral-ep
+      - dpo-mixtral-peft-tp-sp
     with:
       RUNNER: self-hosted-azure
       # Fairly aggressive timeout that all functional tests should try to adhere to
6 changes: 5 additions & 1 deletion Dockerfile
@@ -130,16 +130,20 @@ git fetch -a
 # 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652
 # 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863
 # (superseded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654
+# ba8edbd2063f3349c40c9c73e5bae46abbe65f94: fix: regular torch optims (e.g., sgd) no longer error with closure spec NeMo#11189
+# 35a7f718237cf011215db9e92273ed7236d0e8b1: Fix for crash with LoRA + tp_overlap_comm=false + sequence_parallel=true NeMo#10920
 for pr_and_commit in \
   "10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \
   "10652 60e677423667c029dd05875da72bf0719774f844" \
   "10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \
+  "11189 ba8edbd2063f3349c40c9c73e5bae46abbe65f94" \
+  "10920 53cf6527571b29379188c8bb0dba8e507db3cca1" \
 ; do
   pr=$(cut -f1 -d' ' <<<"$pr_and_commit")
   head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit")
   git fetch origin $head_pr_commit:PR-${pr}
   # cherry-picks all commits between main and the top of the PR
-  git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
+  git cherry-pick -m 1 --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
   # Tag cherry-picks to help
   git tag cherry-pick-PR-${pr}
 done
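Note: the switch to git cherry-pick -m 1 matters because a fetched PR branch may contain merge commits; plain git cherry-pick aborts on a merge commit unless told which parent to diff against, and -m 1 picks each merge relative to its first (mainline) parent.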
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -6,6 +6,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   # dpo specific args
   dpo:
@@ -17,6 +18,7 @@ trainer:
 
     # how many GBS we loop over
     limit_val_batches: 1.0
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # do not change these
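The same two-line change repeats in each config below: a trainer-level gradient_clip_val: 0.0 is introduced for the Megatron Core optimizer path, while the algorithm-level gradient_clip_val: 1.0 (consumed by clip_gradients in nemo_aligner/utils/train_utils.py, shown further down) is kept until the Mcore optimizer becomes the default. In Megatron Core itself the clip value lives on the optimizer, not on a trainer hook: OptimizerConfig carries a clip_grad field, and clipping happens inside optimizer.step(). A minimal sketch; clip_grad and use_distributed_optimizer are real OptimizerConfig fields, but the values, and how NeMo wires the YAML above into this config, are illustrative assumptions:

# Sketch: in Megatron Core the clip value is optimizer state, not a trainer hook.
# Field names are real megatron.core API; values are made up, and the exact
# mapping from the YAML configs above is NeMo plumbing not shown here.
from megatron.core.optimizer import OptimizerConfig

opt_cfg = OptimizerConfig(
    optimizer="adam",
    lr=9e-6,
    use_distributed_optimizer=True,  # the "mcore dist opt" this commit supports
    clip_grad=1.0,                   # clipping applied inside optimizer.step()
)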
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_kto.yaml
@@ -6,6 +6,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   # kto specific args
   kto:
@@ -17,6 +18,7 @@ trainer:
 
     # how many GBS we loop over
     limit_val_batches: 1.0
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # do not change these
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -7,6 +7,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   ppo:
     # How many steps we warm up the critic for (without training the policy)
@@ -21,6 +22,7 @@
     max_steps: -1 # max PPO steps (-1 to go through the whole train set)
     val_check_interval: 10
     save_interval: ${.val_check_interval}
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # PPO args to generate the data for training
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_critic.yaml
@@ -6,6 +6,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   ppo:
     port: 5556
@@ -15,6 +16,7 @@
 
     # used to set the learning rate scheduler
     max_steps: 10000
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
     # a PyTriton parameter to specify
4 changes: 3 additions & 1 deletion examples/nlp/gpt/conf/gpt_rs_actor.yaml
@@ -7,12 +7,14 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   rs:
     max_epochs: 1
     max_steps: -1 # max rs steps (-1 to go through the whole train set)
     val_check_interval: 10
     save_interval: ${.val_check_interval}
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # pick up from the model
@@ -177,4 +179,4 @@ model:
   # define fields from the base model's config that should be ignored when merging with this config.
   overwrite_base_config:
     data:
-      data_prefix: True
\ No newline at end of file
+      data_prefix: True
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_sft.yaml
@@ -5,6 +5,7 @@ trainer:
   devices: 1
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   sft:
     max_epochs: 1
@@ -15,6 +16,7 @@
     limit_train_batches: 1.0
 
     limit_val_batches: 1.0
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # can be used to register any custom metrics that require token-by-token generation
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_spin.yaml
@@ -6,6 +6,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16-mixed
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   # spin specific args
   spin:
@@ -18,6 +19,7 @@
 
     # how many GBS we loop over
     limit_val_batches: 1.0
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # do not change these
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/training_rm.yaml
@@ -6,6 +6,7 @@ trainer:
   devices: 8
   accelerator: gpu
   precision: bf16
+  gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value
 
   # rm specific args
   rm:
@@ -20,6 +21,7 @@
     # set to float for a percentage
     # of the validation dataset
     limit_val_batches: 1.0
+    # TODO: delete once Megatron Core optimizer becomes default
    gradient_clip_val: 1.0
 
   # do not change these
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/critic_server_trainer.py
@@ -322,7 +322,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         if grad_norm is not None:
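This one-line change, passing closure=None explicitly, recurs in dpo.py, ppo.py, rs.py, spin.py, and supervised.py below. Plain torch optimizers define step(closure=None) with a default, but an optimizer wrapper can declare closure as a required parameter, in which case a bare step() raises TypeError while step(closure=None) satisfies both call signatures. A minimal sketch with a hypothetical wrapper (DistOptWrapper is illustrative, not NeMo's actual class):

import torch

class DistOptWrapper:
    """Hypothetical stand-in for an optimizer wrapper whose step()
    requires an explicit closure argument (which may be None)."""

    def __init__(self, inner: torch.optim.Optimizer):
        self.inner = inner

    def step(self, closure):  # no default: a bare step() raises TypeError
        if closure is not None:
            closure()
        self.inner.step()

params = [torch.nn.Parameter(torch.zeros(2))]
for opt in (torch.optim.SGD(params, lr=0.1), DistOptWrapper(torch.optim.SGD(params, lr=0.1))):
    opt.step(closure=None)  # works for both; opt.step() would fail on the wrapper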
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/dpo.py
@@ -220,7 +220,7 @@ def train_single_step(self, global_batch):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         trainer_metrics = {}
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/ppo.py
@@ -440,7 +440,7 @@ def run_training(self, dataloader_iter):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         if grad_norm is not None:
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/rs.py
@@ -294,7 +294,7 @@ def run_training(self, dataloader_iter):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         if grad_norm is not None:
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/spin.py
@@ -195,7 +195,7 @@ def train_single_step(self, global_batch):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         trainer_metrics = {}
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/supervised.py
@@ -150,7 +150,7 @@ def train_single_step(self, batch):
         grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
         lr = self.optimizer.param_groups[0]["lr"]
 
-        self.optimizer.step()
+        self.optimizer.step(closure=None)
         self.scheduler.step()
 
         trainer_metrics = {}
56 changes: 44 additions & 12 deletions nemo_aligner/utils/train_utils.py
@@ -101,31 +101,52 @@ def prepare_for_training_step(ptl_model, zero_grad=True):
             param.data_ptr()
 
 
+# TODO: Delete this once API introduced in NeMo (https://github.com/NVIDIA/NeMo/pull/10803)
+# TODO: Update PR to move this logic into staticmethod in nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
 def grad_reductions(ptl_model):
     # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
     if ptl_model.cfg.get("tensor_model_parallel_size", 1) > 1 and ptl_model.cfg.get("sequence_parallel", False):
-        ptl_model.allreduce_sequence_parallel_gradients()
-
-    if ptl_model.with_distributed_adam:
-        # synchronize asynchronous grad reductions
-        # note: not necessary, but reduces performance degradation
-        # from multiple simultaneous NCCL calls
-        ptl_model._optimizer._finish_bucket_grad_sync()
+        # Mcore DistOpt handles this, so we don't have to
+        if not ptl_model.use_mcore_dist_optim:
+            ptl_model.megatron_timer_start("allreduce_sequence_parallel_gradients", log_level=1)
+            ptl_model.allreduce_sequence_parallel_gradients()
+            ptl_model.megatron_timer_stop("allreduce_sequence_parallel_gradients")
+
+    ptl_model.megatron_timer_start("gradient_allreduce", log_level=1)
+    if ptl_model.use_fsdp:
+        # Reduce the gradients omitted from FSDP-sharding
+        ptl_model.allreduce_fsdp_sharding_omitted_gradients()
+    elif ptl_model.with_distributed_adam:
+        if not ptl_model.use_mcore_dist_optim:
+            # synchronize asynchronous grad reductions
+            # note: not necessary, but reduces performance degradation
+            # from multiple simultaneous NCCL calls
+            ptl_model._optimizer._finish_bucket_grad_sync()
+        # else: Mcore distributed optim calls finalize_model_grads to finish grad sync
     elif ptl_model.megatron_amp_O2:
         # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
-        if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 or ptl_model.cfg.get("sequence_parallel", False):
+        if (
+            ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1
+            or ptl_model.cfg.get("sequence_parallel", False)
+            or not ptl_model.cfg.get("async_grad_allreduce", True)
+        ):
             # main grads are stored in the MainParamsOptimizer wrapper
             ptl_model._optimizer.allreduce_main_grads()
     else:
         # async grad allreduce is not currently implemented for O1/autocasting mixed precision training
         # so we all-reduce gradients after the pipeline
         ptl_model.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)
+    ptl_model.megatron_timer_stop("gradient_allreduce")
 
-    if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 and ptl_model.cfg.get(
-        "share_embeddings_and_output_weights", True
+    if (
+        not ptl_model.use_mcore_dist_optim
+        and ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1
+        and ptl_model.cfg.get("share_embeddings_and_output_weights", True)
     ):
+        ptl_model.megatron_timer_start("allreduce_first_last_embeddings", log_level=1)
         # when using pipeline parallelism the first and last stage must keep embeddings in sync
         ptl_model.allreduce_first_last_embeddings()
+        ptl_model.megatron_timer_stop("allreduce_first_last_embeddings")
 
 
 def prepare_for_validation_step(ptl_model):
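Stripped of the timer calls, the rewritten grad_reductions dispatches on a single question: who synchronizes the gradients? A condensed sketch of that order, assuming the ptl_model attributes used above (the inner amp-O2 conditions are elided):

# Condensed view of the dispatch above; a sketch, not the full function.
def sync_grads(ptl_model):
    if ptl_model.use_fsdp:
        # FSDP shards most grads; all-reduce only the ones it omitted
        ptl_model.allreduce_fsdp_sharding_omitted_gradients()
    elif ptl_model.with_distributed_adam and not ptl_model.use_mcore_dist_optim:
        # flush the async bucket reductions started during backward
        ptl_model._optimizer._finish_bucket_grad_sync()
    elif ptl_model.megatron_amp_O2:
        # fp32 main grads live in the optimizer wrapper (pp/sp/async conditions elided)
        ptl_model._optimizer.allreduce_main_grads()
    else:
        ptl_model.allreduce_gradients()
    # Mcore dist optim needs none of this: finalize_model_grads() runs inside step()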
@@ -155,14 +176,26 @@ def set_eval(ptl_model):
     ptl_model.eval()
 
 
+# TODO: adapt the version in /opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
 def clip_gradients(ptl_model, clip_val):
+    """PTL hook to configure gradients.
+    We use gradient clipping implementation from megatron-lm.
+    """
     if clip_val is None:
         return
 
+    clip_val = float(clip_val)
+    if clip_val <= 0:
+        return
+
+    if ptl_model.with_megatron_fused_adam or ptl_model.use_mcore_dist_optim:
+        # Gradient clipping is done in optimizer step
+        return
+
     if ptl_model.grad_clip_pl_default:
         # use the default behavior
         return super().configure_gradient_clipping(*args, **kwargs)
 
     if ptl_model.with_distributed_adam:
         grad_norm = clip_grad_norm_distributed_optimizer(ptl_model._optimizer, clip_val)
     else:
@@ -171,6 +204,5 @@ def clip_gradients(ptl_model, clip_val):
             parameters = ptl_model._optimizer.get_parameters_with_grad()
         else:
             parameters = ptl_model.get_parameters_with_grad()
-        grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
-
+        grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val, use_fsdp=ptl_model.use_fsdp,)
     return grad_norm
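The net effect of the new early returns: under the Mcore distributed optimizer (or Megatron fused Adam), clip_gradients becomes a no-op and clipping is deferred to optimizer.step(). A small sketch of the observable behavior, using a dummy object in place of the PTL model (attribute names mirror the diff above):

from types import SimpleNamespace

from nemo_aligner.utils.train_utils import clip_gradients

# SimpleNamespace stands in for the PTL model in this sketch.
mcore_model = SimpleNamespace(with_megatron_fused_adam=False, use_mcore_dist_optim=True)

assert clip_gradients(mcore_model, clip_val=None) is None  # clipping disabled outright
assert clip_gradients(mcore_model, clip_val=0.0) is None   # matches trainer.gradient_clip_val: 0.0
assert clip_gradients(mcore_model, clip_val=1.0) is None   # Mcore dist opt clips inside optimizer.step()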