Docstrings revamp. #589

Merged (1 commit) on Jul 29, 2023
30 changes: 27 additions & 3 deletions .pre-commit-config.yaml
@@ -40,10 +40,33 @@ repos:

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.5.1
rev: v1.7.4
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
args:
[
--in-place,
--wrap-summaries=99,
--wrap-descriptions=99,
--style=sphinx,
--black,
]

# python docstring coverage checking
- repo: https://github.com/econchick/interrogate
rev: 1.5.0 # or master if you're bold
hooks:
- id: interrogate
args:
[
--verbose,
--fail-under=80,
--ignore-init-module,
--ignore-init-method,
--ignore-module,
--ignore-nested-functions,
-vv,
]

# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
@@ -53,10 +76,11 @@
args:
[
"--extend-ignore",
"E203,E402,E501,F401,F841",
"E203,E402,E501,F401,F841,RST2,RST301",
"--exclude",
"logs/*,data/*",
]
additional_dependencies: [flake8-rst-docstrings==0.3.0]

# python security linter
- repo: https://github.com/PyCQA/bandit
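For reference, a minimal sketch of the docstring style these hooks now enforce: sphinx/reST field lists (checked by `flake8-rst-docstrings`), 99-character wrapping in `--style=sphinx` and `--black` mode (applied by `docformatter`), and coverage above the 80% threshold (checked by `interrogate`). The function below is illustrative only and is not part of this PR:

```python
def scale(value: float, factor: float = 2.0) -> float:
    """Scale a value by a constant factor.

    :param value: The number to scale.
    :param factor: The multiplier to apply. Defaults to `2.0`.
    :return: The scaled value.
    """
    return value * factor
```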
117 changes: 89 additions & 28 deletions src/data/mnist_datamodule.py
@@ -8,24 +8,42 @@


class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
"""`LightningDataModule` for the MNIST dataset.
A DataModule implements 6 key methods:
The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of
mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
A `LightningDataModule` implements 7 key methods:
```python
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc...
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
# Download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc...
# Things to do on every process in DDP.
# Load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
# return train dataloader
def val_dataloader(self):
# return validation dataloader
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
# return test dataloader
def predict_dataloader(self):
# return predict dataloader
def teardown(self, stage):
# Called on every process in DDP.
# Clean up after fit or test.
```
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
@@ -41,7 +59,15 @@ def __init__(
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
):
) -> None:
"""Initialize a `MNISTDataModule`.
:param data_dir: The data directory. Defaults to `"data/"`.
:param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
:param batch_size: The batch size. Defaults to `64`.
:param num_workers: The number of workers. Defaults to `0`.
:param pin_memory: Whether to pin memory. Defaults to `False`.
"""
super().__init__()

# this line allows accessing init params with the 'self.hparams' attribute
@@ -58,22 +84,33 @@ def __init__(
self.data_test: Optional[Dataset] = None

@property
def num_classes(self):
def num_classes(self) -> int:
"""Get the number of classes.
:return: The number of MNIST classes (10).
"""
return 10

def prepare_data(self):
"""Download data if needed.
def prepare_data(self) -> None:
"""Download data if needed. Lightning ensures that `self.prepare_data()` is called only
within a single process on CPU, so you can safely add your downloading logic within. In
case of multi-node training, the execution of this hook depends upon
`self.prepare_data_per_node()`.
Do not use it to assign state (self.x = y).
"""
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):
def setup(self, stage: Optional[str] = None) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
`trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
`self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
`self.setup()` once the data is prepared and available for use.
:param stage: The stage to set up. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
"""
# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
@@ -86,7 +123,11 @@ def setup(self, stage: Optional[str] = None):
generator=torch.Generator().manual_seed(42),
)

def train_dataloader(self):
def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
:return: The train dataloader.
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
@@ -95,7 +136,11 @@ def train_dataloader(self):
shuffle=True,
)

def val_dataloader(self):
def val_dataloader(self) -> DataLoader[Any]:
"""Create and return the validation dataloader.
:return: The validation dataloader.
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
@@ -104,7 +149,11 @@ def val_dataloader(self):
shuffle=False,
)

def test_dataloader(self):
def test_dataloader(self) -> DataLoader[Any]:
"""Create and return the test dataloader.
:return: The test dataloader.
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
@@ -113,16 +162,28 @@ def test_dataloader(self):
shuffle=False,
)

def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
def teardown(self, stage: Optional[str] = None) -> None:
"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
`trainer.test()`, and `trainer.predict()`.
:param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
Defaults to ``None``.
"""
pass

def state_dict(self):
"""Extra things to save to checkpoint."""
def state_dict(self) -> Dict[Any, Any]:
"""Called when saving a checkpoint. Implement to generate and save the datamodule state.
:return: A dictionary containing the datamodule state that you want to save.
"""
return {}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint. Implement to reload datamodule state given datamodule
`state_dict()`.
:param state_dict: The datamodule state returned by `self.state_dict()`.
"""
pass


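To make the documented `LightningDataModule` lifecycle concrete, here is a minimal usage sketch outside the `Trainer` (the batch shapes assume the standard tensor-transformed MNIST data; the values shown are placeholders, not part of this PR):

```python
from src.data.mnist_datamodule import MNISTDataModule

dm = MNISTDataModule(data_dir="data/", batch_size=64)
dm.prepare_data()          # download MNIST once, on a single process
dm.setup(stage="fit")      # build self.data_train / self.data_val / self.data_test
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)    # e.g. torch.Size([64, 1, 28, 28]) torch.Size([64])
```

In a full run, the same object is normally passed to `trainer.fit(model, datamodule=dm)` and Lightning calls these hooks in the order documented above.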
16 changes: 8 additions & 8 deletions src/eval.py
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

import hydra
import pyrootutils
@@ -30,19 +30,15 @@


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
Args:
cfg (DictConfig): Configuration composed by Hydra.
Returns:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""

assert cfg.ckpt_path

log.info(f"Instantiating datamodule <{cfg.data._target_}>")
@@ -82,6 +78,10 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:

@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
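Since `evaluate` is now annotated as returning `Tuple[Dict[str, Any], Dict[str, Any]]`, the pair can be unpacked directly when the task is driven programmatically. A hedged sketch using Hydra's compose API follows; the checkpoint path is a placeholder, and whether the wrapped task runs cleanly outside `@hydra.main` depends on `utils.task_wrapper`, which is not shown in this diff:

```python
from hydra import compose, initialize

# Compose the evaluation config without going through the CLI entry point.
with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="eval", overrides=["ckpt_path=path/to/checkpoint.ckpt"])

metric_dict, object_dict = evaluate(cfg)  # dict of metrics, dict of instantiated objects
print(metric_dict)
```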
20 changes: 18 additions & 2 deletions src/models/components/simple_dense_net.py
@@ -1,15 +1,26 @@
import torch
from torch import nn


class SimpleDenseNet(nn.Module):
"""A simple fully-connected neural net for computing predictions."""

def __init__(
self,
input_size: int = 784,
lin1_size: int = 256,
lin2_size: int = 256,
lin3_size: int = 256,
output_size: int = 10,
):
) -> None:
"""Initialize a `SimpleDenseNet` module.
:param input_size: The number of input features.
:param lin1_size: The number of output features of the first linear layer.
:param lin2_size: The number of output features of the second linear layer.
:param lin3_size: The number of output features of the third linear layer.
:param output_size: The number of output features of the final linear layer.
"""
super().__init__()

self.model = nn.Sequential(
@@ -25,7 +36,12 @@ def __init__(
nn.Linear(lin3_size, output_size),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform a single forward pass through the network.
:param x: The input tensor.
:return: A tensor of predictions.
"""
batch_size, channels, width, height = x.size()

# (batch, 1, width, height) -> (batch, 1*width*height)
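A quick sketch of the newly documented forward pass on a dummy MNIST-shaped batch, using the default layer sizes (the input shape is an assumption matching `input_size=784`):

```python
import torch

from src.models.components.simple_dense_net import SimpleDenseNet

net = SimpleDenseNet()            # defaults: 784 -> 256 -> 256 -> 256 -> 10
x = torch.randn(32, 1, 28, 28)    # (batch, channels, width, height)
logits = net(x)                   # flattened internally to (32, 784)
print(logits.shape)               # torch.Size([32, 10])
```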