Skip to content

Commit

Permalink
Simplify typing, remove uses of the Algorithm class (#32)
Browse files Browse the repository at this point in the history
* Rename `jax_algo.py` to `jax_example.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify Example algo, remove Algorithm as base

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove Algorithm base from Jax example algo

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove use of the `Algorithm` class in all algos

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make `Algorithm` a protocol class

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix import bug with `Algorithm` class

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* fixup! Remove use of the `Algorithm` class in all algos

* Fix bug in main_test.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue with config name for jax example

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Rename `example_algo.py` to `example.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Move test classes to a new `testsuites` folder

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove `Algorithm` class use from `callback.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove other references to `Algorithm` class

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Move `Algorithm` to the testsuites folder

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Rename Jax example, remove PhaseStr

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove the "manual optimization" example

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Clean up the configs for algorithms

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Create dyn. configs for optimizers and schedulers

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix a bug in hydra-zen for inner classes

- Adds a patch for mit-ll-responsible-ai/hydra-zen#705

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix the no_op algo constructor

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix the configs for algorithms

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix uses of the algorithm config name in tests

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add tests for the optimizer & lr_scheduler configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issues with dynamic configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue with `instantiate` in hydra_utils

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue with required args in configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice authored Jul 15, 2024
1 parent d4891fb commit 0d35a38
Show file tree
Hide file tree
Showing 35 changed files with 848 additions and 627 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
Expand Down Expand Up @@ -32,7 +32,7 @@ repos:

- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.3.3"
rev: "v0.5.1"
hooks:
- id: ruff
args: ['--line-length', '99', '--fix']
Expand All @@ -41,7 +41,7 @@ repos:

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.5.1
rev: v1.7.5
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
Expand All @@ -64,7 +64,7 @@ repos:

# Dependency management
- repo: https://github.com/pdm-project/pdm
rev: 2.12.4
rev: 2.16.1
hooks:
- id: pdm-lock-check
require_serial: true
Expand All @@ -91,7 +91,7 @@ repos:

# word spelling linter
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.3.0
hooks:
- id: codespell
args:
Expand Down
8 changes: 5 additions & 3 deletions project/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from . import algorithms, configs, datamodules, experiment, main, networks, utils
from .algorithms import Algorithm
from .configs import Config
from .configs import Config, add_configs_to_hydra_store
from .experiment import Experiment
from .utils.hydra_utils import patched_safe_name # noqa

# from .networks import FcNet
from .utils.types import DataModule

add_configs_to_hydra_store()


__all__ = [
"algorithms",
"experiment",
Expand All @@ -14,7 +17,6 @@
"configs",
"datamodules",
"networks",
"Algorithm",
"DataModule",
"utils",
# "ExampleAlgorithm",
Expand Down
28 changes: 4 additions & 24 deletions project/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
from hydra_zen import builds, store

from project.algorithms.jax_algo import JaxAlgorithm
from project.algorithms.jax_example import JaxExample
from project.algorithms.no_op import NoOp

from .algorithm import Algorithm
from .example_algo import ExampleAlgorithm
from .manual_optimization_example import ManualGradientsExample

# NOTE: This works the same way as creating config files for each algorithm under
# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
# well as structured config (dataclasses).

# If you add a configuration file under `configs/algorithm`, it will also be available as an option
# from the command-line, and be validated against the schema.
# todo: It might be nicer if we did this this `configs/algorithms` instead of here, no?
algorithm_store = store(group="algorithm")
algorithm_store(ExampleAlgorithm.HParams(), name="example_algo")
algorithm_store(ManualGradientsExample.HParams(), name="manual_optimization")
algorithm_store(builds(NoOp, populate_full_signature=False), name="no_op")
algorithm_store(JaxAlgorithm.HParams(), name="jax_algo")

algorithm_store.add_to_hydra_store()
from .example import ExampleAlgorithm

__all__ = [
"Algorithm",
"ExampleAlgorithm",
"ManualGradientsExample",
"JaxAlgorithm",
"JaxExample",
"NoOp",
]
76 changes: 35 additions & 41 deletions project/algorithms/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
from __future__ import annotations

from collections.abc import Mapping
from logging import getLogger as get_logger
from pathlib import Path
from typing import Literal, override
from typing import Any, Generic, Literal, override

import torch
from lightning import Trainer
from lightning import LightningModule, Trainer
from lightning import pytorch as pl
from typing_extensions import Generic # noqa
from typing_extensions import TypeVar

from project.algorithms.algorithm import Algorithm, BatchType, StepOutputDict, StepOutputType
from project.utils.types import PhaseStr, PyTree
from project.utils.types import PyTree
from project.utils.utils import get_log_dir

logger = get_logger(__name__)

BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True)
StepOutputType = TypeVar(
"StepOutputType",
bound=torch.Tensor | Mapping[str, Any] | None,
default=dict[str, torch.Tensor],
contravariant=True,
)

class Callback[BatchType: PyTree[torch.Tensor], StepOutputType: torch.Tensor | StepOutputDict](
pl.Callback
):

class Callback(pl.Callback, Generic[BatchType, StepOutputType]):
"""Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class.
Adds the following typing information:
Expand All @@ -40,7 +46,7 @@ def __init__(self) -> None:
def setup(
self,
trainer: pl.Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
# todo: "tune" is mentioned in the docstring, is it still used?
stage: Literal["fit", "validate", "test", "predict", "tune"],
) -> None:
Expand All @@ -49,10 +55,10 @@ def setup(
def on_shared_batch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
batch: BatchType,
batch_index: int,
phase: PhaseStr,
phase: Literal["train", "val", "test"],
dataloader_idx: int | None = None,
):
"""Shared hook, called by `on_[train/validation/test]_batch_start`.
Expand All @@ -63,11 +69,11 @@ def on_shared_batch_start(
def on_shared_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
outputs: StepOutputType,
batch: BatchType,
batch_index: int,
phase: PhaseStr,
phase: Literal["train", "val", "test"],
dataloader_idx: int | None = None,
):
"""Shared hook, called by `on_[train/validation/test]_batch_end`.
Expand All @@ -78,8 +84,8 @@ def on_shared_batch_end(
def on_shared_epoch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
phase: PhaseStr,
pl_module: LightningModule,
phase: Literal["train", "val", "test"],
) -> None:
"""Shared hook, called by `on_[train/validation/test]_epoch_start`.
Expand All @@ -89,8 +95,8 @@ def on_shared_epoch_start(
def on_shared_epoch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
phase: PhaseStr,
pl_module: LightningModule,
phase: Literal["train", "val", "test"],
) -> None:
"""Shared hook, called by `on_[train/validation/test]_epoch_end`.
Expand All @@ -101,7 +107,7 @@ def on_shared_epoch_end(
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
outputs: StepOutputType,
batch: BatchType,
batch_index: int,
Expand All @@ -126,7 +132,7 @@ def on_train_batch_end(
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
outputs: StepOutputType,
batch: BatchType,
batch_index: int,
Expand Down Expand Up @@ -154,7 +160,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
outputs: StepOutputType,
batch: BatchType,
batch_index: int,
Expand Down Expand Up @@ -182,7 +188,7 @@ def on_test_batch_end(
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
batch: BatchType,
batch_index: int,
) -> None:
Expand All @@ -199,7 +205,7 @@ def on_train_batch_start(
def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
batch: BatchType,
batch_index: int,
dataloader_idx: int = 0,
Expand All @@ -218,7 +224,7 @@ def on_validation_batch_start(
def on_test_batch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
pl_module: LightningModule,
batch: BatchType,
batch_index: int,
dataloader_idx: int = 0,
Expand All @@ -234,43 +240,31 @@ def on_test_batch_start(
)

@override
def on_train_epoch_start(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_train_epoch_start(trainer, pl_module)
self.on_shared_epoch_start(trainer, pl_module, phase="train")

@override
def on_validation_epoch_start(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_validation_epoch_start(trainer, pl_module)
self.on_shared_epoch_start(trainer, pl_module, phase="val")

@override
def on_test_epoch_start(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_test_epoch_start(trainer, pl_module)
self.on_shared_epoch_start(trainer, pl_module, phase="test")

@override
def on_train_epoch_end(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_train_epoch_end(trainer, pl_module)
self.on_shared_epoch_end(trainer, pl_module, phase="train")

@override
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_validation_epoch_end(trainer, pl_module)
self.on_shared_epoch_end(trainer, pl_module, phase="val")

@override
def on_test_epoch_end(
self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType]
) -> None:
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_test_epoch_end(trainer, pl_module)
self.on_shared_epoch_end(trainer, pl_module, phase="test")
23 changes: 11 additions & 12 deletions project/algorithms/callbacks/classification_metrics.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import warnings
from logging import getLogger as get_logger
from typing import NotRequired, Required, TypedDict, override
from typing import Literal, NotRequired, Required, TypedDict, override

import torch
import torchmetrics
from lightning import LightningModule, Trainer
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy

from project.algorithms.algorithm import Algorithm, BatchType
from project.algorithms.callbacks.callback import Callback
from project.utils.types import PhaseStr
from project.algorithms.callbacks.callback import BatchType, Callback
from project.utils.types.protocols import ClassificationDataModule

logger = get_logger(__name__)


class ClassificationOutputs(TypedDict, total=False):
"""The dictionary format that is minimally required to be returned from
`training/val/test_step` for classification algorithms."""
"""The outputs that should be minimally returned from the training/val/test_step of
classification LightningModules so that metrics can be added aumatically by the
`ClassificationMetricsCallback`."""

loss: NotRequired[torch.Tensor | float]
"""The loss at this step."""
Expand All @@ -31,14 +30,14 @@ class ClassificationOutputs(TypedDict, total=False):


class ClassificationMetricsCallback(Callback[BatchType, ClassificationOutputs]):
"""Callback that adds classification metrics to the pl module."""
"""Callback that adds classification metrics to a LightningModule."""

def __init__(self) -> None:
super().__init__()
self.disabled = False

@classmethod
def attach_to(cls, algorithm: Algorithm, num_classes: int):
def attach_to(cls, algorithm: LightningModule, num_classes: int):
callback = cls()
callback.add_metrics_to(algorithm, num_classes=num_classes)
return callback
Expand Down Expand Up @@ -84,8 +83,8 @@ def _get_metric(pl_module: LightningModule, name: str):
def setup(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, ClassificationOutputs],
stage: PhaseStr,
pl_module: LightningModule,
stage: Literal["fit", "validate", "test", "predict", "tune"],
) -> None:
if self.disabled:
return
Expand All @@ -108,11 +107,11 @@ def setup(
def on_shared_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, ClassificationOutputs],
pl_module: LightningModule,
outputs: ClassificationOutputs,
batch: BatchType,
batch_index: int,
phase: PhaseStr,
phase: Literal["train", "val", "test"],
dataloader_idx: int | None = None,
):
if self.disabled:
Expand Down
Loading

0 comments on commit 0d35a38

Please sign in to comment.