Skip to content

Commit

Permalink
fix: bug fix for KD + PP (#443)
Browse files Browse the repository at this point in the history
Signed-off-by: ashors1 <ashors@nvidia.com>
  • Loading branch information
ashors1 authored Dec 11, 2024
1 parent 4830a07 commit 2ead6bf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
required_keys.update(("tokens", "position_ids"))

if parallel_state.is_pipeline_last_stage():
required_keys.update(("labels", "loss_mask"))
required_keys.update(("labels", "loss_mask", "topk_logits", "topk_token_ids"))

batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}

Expand All @@ -83,7 +83,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

tokens = batch["tokens"]
labels = batch["labels"]
loss_mask = batch["loss_mask"].clamp(min=0, max=1)
loss_mask = batch["loss_mask"]
if loss_mask is not None:
loss_mask = loss_mask.clamp(min=0, max=1)
target_topk_logits = batch["topk_logits"]
target_topk_token_ids = batch["topk_token_ids"]
# Model forward pass
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/kd.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ torchrun --nproc-per-node 2 ${GPFS}/examples/nlp/gpt/train_gpt_knowledge_distill
exp_manager.create_checkpoint_callback=False \
model.data.num_workers=2 \
++model.tensor_model_parallel_size=1 \
++model.pipeline_model_parallel_size=1 \
++model.pipeline_model_parallel_size=2 \
exp_manager.explicit_log_dir=${RESULTS_DIR} \
++model.activations_checkpoint_granularity=full \
++model.activations_checkpoint_method=uniform \
Expand Down

0 comments on commit 2ead6bf

Please sign in to comment.