-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Remove `experiment.py`, move to `main.py` Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Use singledispatch for `train` and `evaluate` Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * WIP: Rework `main.py`, fix resulting errors Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Minor fix for doc / readability of main.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Use full path to datamodules in configs Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Save datamodule on self in text_classifier.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix project/main_test.py::test_help_string Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix test_help_string Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --------- Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
- Loading branch information
Showing
22 changed files
with
438 additions
and
410 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
main is powered by Hydra. | ||
|
||
== Configuration groups == | ||
Compose your configuration from those groups (group=option) | ||
|
||
algorithm: image_classifier, jax_image_classifier, jax_ppo, llm_finetuning, no_op, text_classifier | ||
algorithm/lr_scheduler: CosineAnnealingLR, StepLR | ||
algorithm/network: fcnet, jax_cnn, jax_fcnet, resnet18, resnet50 | ||
algorithm/optimizer: Adam, SGD, custom_adam | ||
cluster: beluga, cedar, current, mila, narval | ||
datamodule: cifar10, fashion_mnist, glue_cola, imagenet, inaturalist, mnist, vision | ||
experiment: cluster_sweep_example, example, jax_rl_example, llm_finetuning_example, local_sweep_example, profiling, text_classification_example | ||
resources: cpu, gpu | ||
trainer: cpu, debug, default, jax_trainer, overfit_one_batch | ||
trainer/callbacks: default, early_stopping, model_checkpoint, model_summary, no_checkpoints, none, rich_progress_bar | ||
trainer/logger: tensorboard, wandb, wandb_cluster | ||
|
||
|
||
== Config == | ||
Override anything in the config (foo.bar=value) | ||
|
||
algorithm: ??? | ||
datamodule: null | ||
trainer: | ||
callbacks: | ||
model_checkpoint: | ||
_target_: lightning.pytorch.callbacks.ModelCheckpoint | ||
dirpath: ${hydra:runtime.output_dir}/checkpoints | ||
filename: epoch_{epoch:03d} | ||
monitor: val/loss | ||
verbose: false | ||
save_last: true | ||
save_top_k: 1 | ||
mode: min | ||
auto_insert_metric_name: false | ||
save_weights_only: false | ||
every_n_train_steps: null | ||
train_time_interval: null | ||
every_n_epochs: null | ||
save_on_train_epoch_end: null | ||
early_stopping: | ||
_target_: lightning.pytorch.callbacks.EarlyStopping | ||
monitor: val/loss | ||
min_delta: 0.0 | ||
patience: 5 | ||
verbose: false | ||
mode: min | ||
strict: true | ||
check_finite: true | ||
stopping_threshold: null | ||
divergence_threshold: null | ||
check_on_train_epoch_end: null | ||
model_summary: | ||
_target_: lightning.pytorch.callbacks.RichModelSummary | ||
max_depth: 2 | ||
rich_progress_bar: | ||
_target_: lightning.pytorch.callbacks.RichProgressBar | ||
lr_monitor: | ||
_target_: lightning.pytorch.callbacks.LearningRateMonitor | ||
device_utilisation: | ||
_target_: lightning.pytorch.callbacks.DeviceStatsMonitor | ||
throughput: | ||
_target_: project.algorithms.callbacks.samples_per_second.MeasureSamplesPerSecondCallback | ||
_target_: lightning.Trainer | ||
accelerator: auto | ||
strategy: auto | ||
devices: 1 | ||
deterministic: false | ||
fast_dev_run: false | ||
min_epochs: 1 | ||
max_epochs: 10 | ||
default_root_dir: ${hydra:runtime.output_dir} | ||
detect_anomaly: false | ||
log_level: info | ||
seed: 123 | ||
name: default | ||
debug: false | ||
verbose: false | ||
ckpt_path: null | ||
|
||
|
||
Powered by Hydra (https://hydra.cc) | ||
Use --hydra-help to view Hydra specific help |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +0,0 @@ | ||
from .image_classifier import ImageClassifier | ||
from .jax_image_classifier import JaxImageClassifier | ||
from .jax_ppo import JaxRLExample | ||
from .no_op import NoOp | ||
from .text_classifier import TextClassifier | ||
|
||
__all__ = [ | ||
"ImageClassifier", | ||
"JaxImageClassifier", | ||
"NoOp", | ||
"TextClassifier", | ||
"JaxRLExample", | ||
] | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
defaults: | ||
- mnist | ||
- _self_ | ||
_target_: project.datamodules.FashionMNISTDataModule | ||
_target_: project.datamodules.image_classification.fashion_mnist.FashionMNISTDataModule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
defaults: | ||
- vision | ||
- _self_ | ||
_target_: project.datamodules.ImageNetDataModule | ||
_target_: project.datamodules.image_classification.imagenet.ImageNetDataModule | ||
# todo: add good configuration options here. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
defaults: | ||
- vision | ||
- _self_ | ||
_target_: project.datamodules.INaturalistDataModule | ||
_target_: project.datamodules.image_classification.inaturalist.INaturalistDataModule | ||
version: "2021_train" | ||
target_type: "full" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.