More detailed logging, and save failed checkpoints.
GeorgiosSmyrnis committed May 14, 2024
1 parent 5d9d0d0 commit 64a8554
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 10 additions & 1 deletion open_lm/main.py
@@ -199,13 +199,15 @@ def save_checkpoint(
     scaler,
     completed_epoch,
     evaluation_metrics,
+    percentage_of_data_seen,
     step,
     is_final_checkpoint,
     next_shard_per_source=None,
     samples_seen=None,
     shard_shuffle_seed=None,
     train_data_string=None,
     averagers=None,
+    failed=False,
 ):
     cpu_state, optim_state = None, None
     if args.logs and args.logs.lower() != "none" and args.fsdp:
@@ -247,6 +249,7 @@ def save_checkpoint(
         "name": args.name,
         "is_final_checkpoint": is_final_checkpoint,
         "evaluation_metrics": evaluation_metrics,
+        "percentage_of_data_seen": percentage_of_data_seen,
     }
     if next_shard_per_source is not None:
         checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source
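The new field travels with the stats dict, so it can be inspected offline after a run. A minimal sketch of reading it back, assuming a stats checkpoint produced by this function (the stats_ prefix and the path here are illustrative assumptions, not fixed by this diff):

    import torch

    # Illustrative path; the stats dict is saved alongside the model weights.
    stats = torch.load("logs/my_run/checkpoints/stats_12.pt", map_location="cpu")
    print(stats["percentage_of_data_seen"], stats["is_final_checkpoint"])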
@@ -278,7 +281,8 @@ def save_checkpoint(
         or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
     ):
         for prefix in prefixes:
-            path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
+            save_path = args.checkpoint_path if not failed else args.failed_checkpoint_path
+            path = os.path.join(save_path, f"{prefix}{completed_epoch}.pt")
             print(f"Saving {prefix}{completed_epoch} in {path}...")
             torch.save(
                 prefixes[prefix],
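The routing itself is a one-line branch. As a standalone sketch (hypothetical helper and names, not code from this repository):

    import os

    def resolve_checkpoint_path(checkpoint_dir, failed_dir, prefix, epoch, failed=False):
        # Divert checkpoints from incomplete epochs into a sibling directory,
        # so tooling that scans checkpoint_dir for resumption never sees them.
        save_dir = checkpoint_dir if not failed else failed_dir
        return os.path.join(save_dir, f"{prefix}{epoch}.pt")

    resolve_checkpoint_path("ckpts", "ckpts_failed", "epoch_", 3, failed=True)
    # -> 'ckpts_failed/epoch_3.pt'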
@@ -375,6 +379,7 @@ def main(args):
     args.wandb = "wandb" in args.report_to or "all" in args.report_to
     args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
     args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
+    args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed")
     if is_master(args):
         args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ""
         for dirname in [args.tensorboard_path, args.checkpoint_path]:
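Note that the loop above still creates only args.tensorboard_path and args.checkpoint_path, and torch.save does not create missing parent directories. A guard like the following would be needed somewhere before the first failed checkpoint is written (an assumption about deployment, not part of this diff):

    import os

    # Hypothetical guard: ensure the failed-checkpoint directory exists
    # before torch.save() attempts to write into it.
    os.makedirs(args.failed_checkpoint_path, exist_ok=True)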
@@ -840,8 +845,10 @@ def main(args):
             logging.info("Training exiting due to NaN value")
             break

+        failed_ckpt = False
         expected_steps = data["train"].dataloader.num_batches
         if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training:
+            failed_ckpt = True
             num_ckpt_too_few_tokens += 1
             if is_master(args):
                 logging.warning(
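The test above treats an epoch as failed when it completes less than a (1 - args.data_tolerate_error_p) fraction of the expected steps. A small numeric sketch of the threshold (values are illustrative):

    expected_steps = 1000
    data_tolerate_error_p = 0.1  # tolerate losing up to 10% of an epoch

    # Threshold is (1 - 0.1) * 1000 = 900 completed steps.
    for steps_done_epoch in (950, 900, 899):
        failed_ckpt = steps_done_epoch < (1 - data_tolerate_error_p) * expected_steps
        print(steps_done_epoch, failed_ckpt)  # 950 False, 900 False, 899 True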
@@ -901,13 +908,15 @@ def main(args):
                 scaler,
                 epoch,
                 evaluation_metrics,
+                percentage_of_data_seen=1.0 * steps_done_epoch / expected_steps,
                 step=global_step,
                 is_final_checkpoint=done_training,
                 next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None,
                 samples_seen=samples_seen if args.dataset_manifest is not None else None,
                 shard_shuffle_seed=args.shard_shuffle_seed,
                 train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None,
                 averagers=averagers,
+                failed=failed_ckpt,
             )

         if done_training:
2 changes: 2 additions & 0 deletions open_lm/train.py
@@ -302,6 +302,8 @@ def train_one_epoch(
                 "samples_per_second_per_gpu": samples_per_second_per_gpu,
                 "lr": optimizer.param_groups[0]["lr"],
                 "tokens": (step + 1) * args.global_batch_size * args.seq_len,
+                "expected_steps_epoch": data["train"].dataloader.num_batches,
+                "seen_steps_epoch": batch_count
             }

             if averagers is not None and args.log_avg_model_training_loss:
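Together the two counters make each per-step log record self-describing: any single record is enough to recover how much of the epoch had been seen. A sketch against a hypothetical logged payload (not this repository's logging code):

    log_record = {"expected_steps_epoch": 1000, "seen_steps_epoch": 874}

    # Same ratio that main.py now passes to save_checkpoint
    # as percentage_of_data_seen.
    fraction_seen = log_record["seen_steps_epoch"] / log_record["expected_steps_epoch"]
    print(f"{fraction_seen:.1%}")  # 87.4%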
