
Commit

Add option to disable checkpointing
During parallel training on several different KGs simultaneously on an HPC cluster, checkpointing can result in models loading checkpoints that belong to different KGs, which leads to inconsistency.

This commit adds an option to disable checkpointing in order to avoid that inconsistency.
sshivam95 committed Jun 26, 2024
1 parent 047db5a commit b5925dc
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions dicee/scripts/run.py
@@ -123,6 +123,7 @@ def get_default_arguments(description=None):
                         help="Stochastic weight averaging")
     parser.add_argument('--degree', type=int, default=0,
                         help='degree for polynomial embeddings')
+    parser.add_argument('--disable_checkpointing', action='store_false', help='Disable creation of checkpoints during training')

     if description is None:
         return parser.parse_args()
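A note on the flag's semantics, shown as a minimal standalone argparse sketch (the bare parser and the print calls below are illustrative, not part of the commit): with action='store_false', the parsed value disable_checkpointing defaults to True and flips to False only when the flag is supplied.

import argparse

# Minimal reproduction of the new flag's behaviour (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument('--disable_checkpointing', action='store_false',
                    help='Disable creation of checkpoints during training')

print(parser.parse_args([]).disable_checkpointing)
# -> True: checkpointing stays enabled when the flag is omitted
print(parser.parse_args(['--disable_checkpointing']).disable_checkpointing)
# -> False: passing the flag stores False, which disables checkpointing downstream

That parsed value is what the second file in this commit forwards to Lightning's enable_checkpointing argument.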
3 changes: 2 additions & 1 deletion dicee/trainer/dice_trainer.py
@@ -75,7 +75,8 @@ def initialize_trainer(args, callbacks):
                           max_steps=kwargs.get("max_step", -1),
                           min_steps=kwargs.get("min_steps", None),
                           detect_anomaly=False,
-                          barebones=False)
+                          barebones=False,
+                          enable_checkpointing=kwargs['disable_checkpointing'])
     else:
         print('Initialize TorchTrainer CPU Trainer', end='\t')
         return TorchTrainer(args, callbacks=callbacks)
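For context, a condensed sketch of how the parsed value is assumed to reach the Lightning trainer; the wrapper function name and the idea that kwargs comes from the parsed arguments (e.g. vars(args)) are assumptions, while the keyword arguments mirror the diff above.

import pytorch_lightning as pl

def initialize_pl_trainer_sketch(kwargs, callbacks):
    # kwargs is assumed to hold the parsed CLI arguments, e.g. kwargs = vars(args).
    return pl.Trainer(
        callbacks=callbacks,
        max_steps=kwargs.get("max_step", -1),
        min_steps=kwargs.get("min_steps", None),
        detect_anomaly=False,
        barebones=False,
        # 'disable_checkpointing' defaults to True (argparse store_false), so
        # checkpoints are still written unless --disable_checkpointing is passed,
        # which stores False and thereby turns enable_checkpointing off.
        enable_checkpointing=kwargs['disable_checkpointing'],
    )

Despite the inverted name, the behaviour is consistent: the default keeps checkpoint creation on, and supplying --disable_checkpointing switches it off for every parallel run on the cluster.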
