Skip to content

Commit

Permalink
Tweak datamodule configs to use torchvision dir (#37)
Browse files Browse the repository at this point in the history
* Tweak datamodule configs to use torchvision dir

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix linting errors

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice authored Aug 16, 2024
1 parent 1241d58 commit 1a635ca
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 21 deletions.
12 changes: 0 additions & 12 deletions project/configs/datamodule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
from logging import getLogger as get_logger
from pathlib import Path

from hydra_zen import store

from project.utils.env_vars import NETWORK_DIR

logger = get_logger(__name__)

# Shared directory with pre-downloaded torchvision datasets, when it exists on
# this machine (e.g. on a cluster filesystem). Stays None otherwise so that
# consumers can fall back to a local data directory.
# NOTE(review): assumes NETWORK_DIR (defined above) is None when not on the
# cluster — the `and` chain short-circuits in that case.
torchvision_dir: Path | None = None
"""Network directory with torchvision datasets."""
if (
    NETWORK_DIR
    and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists()
    and _torchvision_dir.is_dir()
):
    torchvision_dir = _torchvision_dir


# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
# (for example, ImageNet32DataModule has a new constructor argument which can't be set atm in the
Expand Down
1 change: 1 addition & 0 deletions project/configs/datamodule/cifar10.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Config for the CIFAR10 datamodule, extending the shared `vision` defaults.
defaults:
  - vision
_target_: project.datamodules.CIFAR10DataModule
# Resolved via the `constant:` resolver: uses the shared torchvision directory
# when it exists on this machine, otherwise falls back to DATA_DIR.
data_dir: ${constant:torchvision_dir,DATA_DIR}
batch_size: 128
train_transforms:
  _target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms
1 change: 1 addition & 0 deletions project/configs/datamodule/mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Config for the MNIST datamodule, extending the shared `vision` defaults.
defaults:
  - vision
_target_: project.datamodules.MNISTDataModule
# Resolved via the `constant:` resolver: uses the shared torchvision directory
# when it exists on this machine, otherwise falls back to DATA_DIR.
data_dir: ${constant:torchvision_dir,DATA_DIR}
normalize: True
batch_size: 128
train_transforms:
Expand Down
27 changes: 20 additions & 7 deletions project/configs/experiment/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@ defaults:
- override /network: resnet18
- override /trainer: default
- override /trainer/callbacks: default
- override /trainer/logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
name: example

seed: ${oc.env:SLURM_PROCID,12345}

trainer:
min_epochs: 1
max_epochs: 10
gradient_clip_val: 0.5


algorithm:
hp:
optimizer:
Expand All @@ -29,4 +25,21 @@ algorithm:
# Overrides for the datamodule selected in the defaults list above.
datamodule:
  batch_size: 64

name: example

# Trainer overrides, including the Weights & Biases logger configuration.
trainer:
  min_epochs: 1
  max_epochs: 10
  gradient_clip_val: 0.5
  logger:
    wandb:
      project: "ResearchTemplate"
      # Tie the run name/id to the SLURM job so a preempted job resumes the
      # same wandb run when restarted.
      name: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID}
      save_dir: "${hydra:runtime.output_dir}"
      offline: False # set True to store all logs only locally
      id: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID} # pass correct id to resume experiment!
      # entity: "" # set to name of your wandb team
      log_model: False
      prefix: ""
      job_type: "train"
      group: ${oc.env:SLURM_JOB_ID}
      tags: ["${name}"]
37 changes: 35 additions & 2 deletions project/utils/env_vars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import importlib
import os
from logging import getLogger as get_logger
from pathlib import Path

import torch

logger = get_logger(__name__)


SLURM_JOB_ID: int | None = (
int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
)
Expand Down Expand Up @@ -69,9 +74,37 @@
"""Local Directory where datasets should be extracted on this machine."""


def get_constant(name: str):
# Shared directory with pre-downloaded torchvision datasets, when it exists on
# this machine (e.g. on a cluster filesystem). Stays None otherwise so that
# configs can fall back to a local data directory via the `constant:` resolver.
# NOTE(review): assumes NETWORK_DIR (defined above) is None when not on the
# cluster — the `and` chain short-circuits in that case.
torchvision_dir: Path | None = None
"""Network directory with torchvision datasets."""
if (
    NETWORK_DIR
    and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists()
    and _torchvision_dir.is_dir()
):
    torchvision_dir = _torchvision_dir


def get_constant(*names: str):
    """Resolver for Hydra to get the value of a constant in this file.

    Each entry in `names` is either the name of a module-level variable in this
    file, or a dotted import path (e.g. ``"some_module.some_attr"``), in which
    case the module is imported and the attribute chain is resolved. Names are
    tried in order and the first non-None value is returned.

    Raises:
        ValueError: If called without any names.
        RuntimeError: If every name resolves to a None value.
    """
    if not names:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise ValueError("At least one name must be passed to get_constant.")
    for name in names:
        if name in globals():
            obj = globals()[name]
            if obj is None:
                logger.debug("Value of %s is None, moving on to the next value.", name)
                continue
            return obj
        # Not a constant defined in this file: treat it as a dotted import path.
        parts = name.split(".")
        obj = importlib.import_module(parts[0])
        for part in parts[1:]:
            obj = getattr(obj, part)
        if obj is not None:
            return obj
        logger.debug("Value of %s is None, moving on to the next value.", name)

    if len(names) == 1:
        raise RuntimeError(f"Could not find non-None value for name {names[0]}")
    raise RuntimeError(f"Could not find non-None value for names {names}")


NUM_WORKERS = int(
Expand Down

0 comments on commit 1a635ca

Please sign in to comment.