From 3d1413cf5bc836b64c71eecfe3d3e55f5c80be0a Mon Sep 17 00:00:00 2001
From: Giulio Starace <26286291+thesofakillers@users.noreply.github.com>
Date: Mon, 16 Jan 2023 12:53:51 +0100
Subject: [PATCH] we'll just use the gewechselt.py file's main to init
 gewechselt models

---
 claficle/run/wechsel_init.py      | 78 -------------------------------
 slurm/wechsel_init/init.array.job |  2 +-
 2 files changed, 1 insertion(+), 79 deletions(-)
 delete mode 100644 claficle/run/wechsel_init.py

diff --git a/claficle/run/wechsel_init.py b/claficle/run/wechsel_init.py
deleted file mode 100644
index 27bd9c2..0000000
--- a/claficle/run/wechsel_init.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""
-WECHSEL initializations involve the training of a tokenizer, and can therefore
-be a lengthy process on their own. This script separates that process
-"""
-import os
-
-from omegaconf import DictConfig, OmegaConf
-import transformers
-import pytorch_lightning as pl
-import torch
-import hydra
-
-from claficle.data.oscar import OSCARDataModule
-from claficle.models.gewechselt import Gewechselt
-
-
-@hydra.main(version_base=None, config_path="../conf", config_name="wechsel_init")
-def main(cfg: DictConfig):
-    pl.seed_everything(cfg.seed)
-
-    # we'll need a Trainer instance to save a checkpoint
-    log_save_dir = os.path.join(
-        cfg.trainer.log_dir, cfg.model.name, f"seed_{cfg.seed}", cfg.model.target_lang
-    )
-    os.makedirs(log_save_dir, exist_ok=True)
-    script_host = "slurm" if "SLURM_JOB_ID" in os.environ else "local"
-    logger = pl.loggers.WandbLogger(
-        save_dir=log_save_dir,
-        job_type="wechsel_init",
-        project="claficle",
-        entity="giulio-uva",
-        mode="disabled" if cfg.trainer.disable_wandb else "online",
-        group=script_host,
-        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
-    )
-    trainer = pl.Trainer(
-        max_epochs=1,
-        logger=logger,
-        enable_progress_bar=cfg.trainer.progress_bar,
-        accelerator=cfg.trainer.accelerator,
-        devices=cfg.trainer.devices,
-    )
-
-    tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.model.causalLM_variant)
-    tokenizer.pad_token = tokenizer.eos_token
-
-    oscar = OSCARDataModule(config=cfg.data, lang=cfg.model.target_lang, seed=cfg.seed)
-    oscar.set_tokenizer(tokenizer)
-    oscar.prepare_data()
-    oscar.setup()
-
-    # this will take a while
-    model: Gewechselt = Gewechselt(cfg.model, oscar.train_dataset)
-
-    # just so that we can save a PL checkpoint of the model
-    trainer.predict(
-        model,
-        dataloaders=torch.utils.data.DataLoader(
-            oscar.val_dataset_tokens.select([1]),
-            batch_size=1,
-            collate_fn=oscar.collate_fn,
-        ),
-        return_predictions=False,
-    )
-
-    # save the checkpoint
-    prefix = (
-        cfg.model.base_checkpoint.split(".")[0]
-        if cfg.model.base_checkpoint is not None
-        else cfg.model.causalLM_variant
-    )
-    trainer.save_checkpoint(
-        "checkpoints/" f"{prefix}_{cfg.model.name}_{cfg.model.target_lang}_init.ckpt"
-    )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/slurm/wechsel_init/init.array.job b/slurm/wechsel_init/init.array.job
index 555836e..c472763 100644
--- a/slurm/wechsel_init/init.array.job
+++ b/slurm/wechsel_init/init.array.job
@@ -20,5 +20,5 @@ source activate claficle
 
 HPARAMS_FILE=slurm/wechsel_init/init.array.txt
 
-srun python -u claficle/run/wechsel_init.py \
+srun python -u claficle/models/gewechselt.py \
 $(head -$SLURM_ARRAY_TASK_ID $HPARAMS_FILE | tail -1)
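For context, the SLURM array job now invokes claficle/models/gewechselt.py directly. Below is a minimal sketch of the entry point that module is assumed to expose; it is hypothetical and merely mirrors the Hydra-driven main() removed from claficle/run/wechsel_init.py above, not the actual contents of gewechselt.py (config_path and config_name are carried over from the deleted script as assumptions):

    # Hypothetical sketch: assumes gewechselt.py exposes a Hydra entry point
    # equivalent to the main() deleted from claficle/run/wechsel_init.py.
    import hydra
    import pytorch_lightning as pl
    from omegaconf import DictConfig

    @hydra.main(version_base=None, config_path="../conf", config_name="wechsel_init")
    def main(cfg: DictConfig):
        pl.seed_everything(cfg.seed)
        # ... build the tokenizer, OSCARDataModule, and Gewechselt model here,
        # then save a checkpoint, as the deleted wechsel_init.py did ...

    if __name__ == "__main__":
        main()

With such a guard in place, the new srun line ("python -u claficle/models/gewechselt.py" plus the per-task overrides read from init.array.txt) performs the same WECHSEL initialization the standalone script used to handle.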