Skip to content

Commit

Permalink
Add docstrings in Callback class
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Jun 25, 2024
1 parent e1ffefb commit 4633846
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
3 changes: 1 addition & 2 deletions project/algorithms/algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,6 @@ class AllParamsShouldHaveGradients(GetGradientsCallback):
def __init__(self, exceptions: Sequence[str] = ()) -> None:
super().__init__()
self.exceptions = exceptions

self.gradients: dict[str, Tensor] = {}

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
Expand All @@ -655,7 +654,7 @@ def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> Non
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
pl_module: Algorithm,
outputs: STEP_OUTPUT,
batch: Any,
batch_index: int,
Expand Down
37 changes: 32 additions & 5 deletions project/algorithms/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,18 @@
class Callback[BatchType: PyTree[torch.Tensor], StepOutputType: torch.Tensor | StepOutputDict](
pl.Callback
):
"""Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class."""
"""Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class.
Adds the following typing information:
- The type of inputs that the algorithm takes
- The type of outputs that are returned by the algorithm's `[training/validation/test]_step` methods.
Adds the following methods:
- `on_shared_batch_start`: called by `on_[train/validation/test]_batch_start`
- `on_shared_batch_end`: called by `on_[train/validation/test]_batch_end`
- `on_shared_epoch_start`: called by `on_[train/validation/test]_epoch_start`
- `on_shared_epoch_end`: called by `on_[train/validation/test]_epoch_end`
"""

def __init__(self) -> None:
super().__init__()
Expand All @@ -43,7 +54,11 @@ def on_shared_batch_start(
batch_index: int,
phase: PhaseStr,
dataloader_idx: int | None = None,
): ...
):
"""Shared hook, called by `on_[train/validation/test]_batch_start`.
Use this if you want to do something at the start of batches in more than one phase.
"""

def on_shared_batch_end(
self,
Expand All @@ -54,21 +69,33 @@ def on_shared_batch_end(
batch_index: int,
phase: PhaseStr,
dataloader_idx: int | None = None,
): ...
):
"""Shared hook, called by `on_[train/validation/test]_batch_end`.
Use this if you want to do something at the end of batches in more than one phase.
"""

def on_shared_epoch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
phase: PhaseStr,
) -> None: ...
) -> None:
"""Shared hook, called by `on_[train/validation/test]_epoch_start`.
Use this if you want to do something at the start of epochs in more than one phase.
"""

def on_shared_epoch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputType],
phase: PhaseStr,
) -> None: ...
) -> None:
"""Shared hook, called by `on_[train/validation/test]_epoch_end`.
Use this if you want to do something at the end of epochs in more than one phase.
"""

@override
def on_train_batch_end(
Expand Down
4 changes: 4 additions & 0 deletions project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import override

from lightning import LightningModule, Trainer
from torch import Tensor
Expand All @@ -16,6 +17,7 @@ def __init__(self):
self.last_update_time: dict[int, float | None] = {}
self.num_optimizers: int | None = None

@override
def on_shared_epoch_start(
self,
trainer: Trainer,
Expand All @@ -31,6 +33,7 @@ def on_shared_epoch_start(
else:
self.num_optimizers = len(optimizer_or_optimizers)

@override
def on_shared_batch_end(
self,
trainer: Trainer,
Expand Down Expand Up @@ -66,6 +69,7 @@ def on_shared_batch_end(
# todo: support other kinds of batches
self.last_step_times[phase] = now

@override
def on_before_optimizer_step(
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int = 0
) -> None:
Expand Down

0 comments on commit 4633846

Please sign in to comment.