Skip to content

Commit

Permalink
disable checkpointing when tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
thesofakillers committed Jan 16, 2023
1 parent 7483f6a commit 511e5d9
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion claficle/run/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(cfg: DictConfig):
# data
oscar = OSCARDataModule(config=cfg.data, lang=lang, seed=cfg.seed)
tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.model.causalLM_variant)
oscar.set_tokenizer(tokenizer) # necessary for collate_fn
oscar.set_tokenizer(tokenizer) # necessary for collate_fn

# set up pl trainer (tuner)
log_save_dir = os.path.join(
Expand Down Expand Up @@ -54,12 +54,15 @@ def main(cfg: DictConfig):
"log_every_n_steps": cfg.trainer.val_check_interval,
"callbacks": [timer],
}
else:
raise ValueError(f"Unknown tune_mode: {cfg.tune_mode}")
trainer = pl.Trainer(
max_epochs=1,
logger=logger,
enable_progress_bar=cfg.trainer.progress_bar,
accelerator=cfg.trainer.accelerator,
devices=cfg.trainer.devices,
enable_checkpointing=False,
**trainer_kwargs,
)
model.train_mode = cfg.trainer.train_mode
Expand Down

0 comments on commit 511e5d9

Please sign in to comment.