Cleanup code and improve docs (#75)
* Mount the UV cache in devcontainer->faster builds

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Move the "references" section

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add more text in the Jax RL Example

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify the paragraph headers a bit

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Revert "Move the "references" section"

This reverts commit 26ed330.

* Cleanup `docs/generate_reference_docs.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Various small docstring changes

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Reorganize docs a bit

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make the auto-schema plugin more portable

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue in import_object

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Move / rename doc files

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* More helpful error msg in instance_attr resolver

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix [jax] in example link

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add a bit of text

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix bug in auto_schema.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Update the cluster sweep example config

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add a "--watch" flag to auto-schema util

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix outdated docstring

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify `main.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix broken test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix broken test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Rename `overrides`->`command_line_overrides`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove outdated auto_schema_test.py module

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Also ignore regression files folder

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove buggy test, add note and todo

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove outdated comment

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Try to fix weird testing bug

Sharing state between tests is causing issues!

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix bug when `trainer.logger` is None

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* "fix" bug in config_test.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Update config_test.py

* Update config_test.py

* Update config_test.py

* Update main_test.py

* Update main_test.py

* Update main_test.py

* Update profiling_test.py

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Oct 30, 2024
1 parent 7b3fea2 commit 6400ed3
Showing 34 changed files with 450 additions and 1,455 deletions.
40 changes: 20 additions & 20 deletions .devcontainer/devcontainer.json
@@ -11,8 +11,7 @@
},
"ghcr.io/devcontainers-contrib/features/apt-get-packages": {
"packages": [
"vim",
"sshfs"
"vim"
]
},
"ghcr.io/va-h/devcontainers-features/uv:1": {}
@@ -52,7 +51,8 @@
".venv": true,
".pytest_cache": true,
".benchmarks": true,
".ruff_cache": true
".ruff_cache": true,
".regression_files": true
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
@@ -75,46 +75,46 @@
]
}
},
// create the cache dir on the host machine if it doesn't exist yet so the mount below
// doesn't fail.
"initializeCommand": {
"create fake SLURM_TMPDIR": "mkdir -p ${SLURM_TMPDIR:-/tmp/slurm_tmpdir}", // this is fine on compute nodes
"create ssh cache dir": "mkdir -p ~/.cache/ssh", // Used to store the ssh sockets (from ControlPath directive).
"create uv cache dir": "mkdir -p ~/.cache/uv" // Used to store the ssh sockets (from ControlPath directive).
},
"containerEnv": {
"SCRATCH": "/home/vscode/scratch",
"SLURM_TMPDIR": "/tmp",
"NETWORK_DIR": "/network",
"UV_LINK_MODE": "symlink"
"UV_LINK_MODE": "symlink",
"UV_CACHE_DIR": "/home/vscode/.uv_cache"
},
"mounts": [
// https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount
// Mount a directory which will contain the pdm installation cache (shared with the host machine).
// This will use $SCRATCH/.cache/pdm, otherwise
// Mount a "$SCRATCH" directory in the host to ~/scratch in the container.
"source=${localEnv:SCRATCH},target=/home/vscode/scratch,type=bind,consistency=cached",
// Mount a /network to match the /network directory on the host.
// FIXME: This assumes that either the NETWORK_DIR environment variable is set on the host, or
// that the /network directory exists.
"source=${localEnv:NETWORK_DIR:/network},target=/network,type=bind,readonly",
// Mount $SLURM_TMPDIR on the host machine to /tmp/slurm_tmpdir in the container.
// note: there's also a SLURM_TMPDIR env variable set to /tmp/slurm_tmpdir in the container.
// NOTE: this assumes that either $SLURM_TMPDIR is set on the host machine (e.g. a compute node)
// or that `/tmp/slurm_tmpdir` exists on the host machine.
"source=${localEnv:SLURM_TMPDIR:/tmp/slurm_tmpdir},target=/tmp,type=bind,consistency=cached",
// Mount the ssh directory on the host machine to the container.
"source=${localEnv:HOME}/.ssh,target=/home/vscode/.ssh,type=bind,readonly"
// Mount the ssh directory on the host machine to the container so we can use SSH in the
// same way as on the local machine.
"source=${localEnv:HOME}/.ssh,target=/home/vscode/.ssh,type=bind,readonly",
// Mount the uv cache directory on the host machine to the container.
"source=${localEnv:HOME}/.cache/uv,target=/home/vscode/.uv_cache,type=bind,consistency=cached"
],
"runArgs": [
"--gpus",
"all",
"--gpus", // COMMENT OUT IF YOUR LAPTOP DOES NOT HAVE A GPU!
"all", // COMMENT OUT IF YOUR LAPTOP DOES NOT HAVE A GPU!
"--ipc=host"
],
// create the pdm cache dir on the host machine if it doesn exist yet so the mount above
// doesn't fail.
"initializeCommand": {
"create fake SLURM_TMPDIR": "mkdir -p ${SLURM_TMPDIR:-/tmp/slurm_tmpdir}", // this is fine on compute nodes
"create ssh cache dir": "mkdir -p ~/.cache/ssh"
},
"onCreateCommand": {
"pre-commit": "pre-commit install --install-hooks"
},
"updateContentCommand": {
"Sync dependencies": "uv sync --frozen"
"Sync dependencies": "uv sync --locked"
},
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": {
10 changes: 5 additions & 5 deletions docs/SUMMARY.md
@@ -1,20 +1,20 @@
* [Home](index.md)
* [Intro](intro.md)
* Features 🔥
* [Features 🔥](features/index.md)
* [Magic Config Schemas](features/auto_schema.md)
* [Jax and Torch support with Lightning ⚡](features/jax.md)
* [Launching Jobs on Remote Clusters](features/remote_slurm_launcher.md)
* [Thorough automated testing on SLURM clusters](features/testing.md)
* features/*.md
* Reference 🤓
* reference/*
* Examples 🧪
* [Image Classification (⚡)](examples/supervised_learning.md)
* [Image Classification ([jax](+⚡)](examples/jax_sl_example.md)
* [Image Classification (jax+⚡)](examples/jax_sl_example.md)
* [NLP (🤗+⚡)](examples/nlp.md)
* [RL (jax)](examples/jax_rl_example.md)
* [Running sweeps](examples/sweeps.md)
* [Profiling your code📎](examples/profiling.md)
* [Related projects](related.md)
* Reference 🤓
* reference/*
* [Learning Resources](resources.md)
* [Getting Help](help.md)
* [Contributing](contributing.md)
30 changes: 26 additions & 4 deletions docs/examples/jax_rl_example.md
@@ -5,25 +5,47 @@ additional_python_references:
---


# Reinforcement Learning (Jax)
# Reinforcement Learning in Jax


This example follows the same structure as the other examples:
- An "algorithm" (in this case `JaxRLExample`) is trained with a "trainer";

- An "algorithm" (in this case `JaxRLExample`) is trained with a "trainer" (`JaxTrainer`);


However, there are some very important differences:
- There is not "datamodule".

- There is no "datamodule". The algorithm accepts an Environment (`gymnax.Environment`) as input.
- The "Trainer" is a `JaxTrainer`, instead of a `lightning.Trainer`.
- The full training loop is written in Jax;
- Some (but not all) PyTorch-Lightning callbacks can still be used with the JaxTrainer;
- The `JaxRLExample` class is an algorithm based on rejax.PPO.




## JaxRLExample

The `JaxRLExample` class is a
The `JaxRLExample` is based on [rejax.PPO](https://github.com/keraJLi/rejax/blob/main/rejax/algos/ppo.py).
It follows the structure of a `JaxModule`, and is trained with a `JaxTrainer`.


??? note "Click to show the code for JaxRLExample"
{{ inline('project.algorithms.jax_rl_example.JaxRLExample', 4) }}


## JaxModule

The `JaxModule` class is made to look a bit like the `lightning.LightningModule` class:

{{ inline('project.trainers.jax_trainer.JaxModule', 0) }}


## JaxTrainer

The `JaxTrainer` follows a roughly similar structure as the `lightning.Trainer`:
- `JaxTrainer.fit` is called with a `JaxModule` to train the algorithm.


??? note "Click to show the code for JaxTrainer"
{{ inline('project.trainers.jax_trainer.JaxTrainer', 4) }}
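
A rough sketch of how these pieces fit together, for orientation: only the names `JaxRLExample`, `JaxTrainer` and the use of a `gymnax.Environment` come from the docs above. Every constructor argument and the `fit` signature below are assumptions rather than the project's actual API; the inlined `JaxRLExample`/`JaxTrainer` code in the rendered docs is authoritative.

```python
# Illustrative sketch only: constructor arguments and the `fit` signature are
# assumptions; see the inlined JaxRLExample/JaxTrainer code for the real API.
import gymnax
import jax

from project.algorithms.jax_rl_example import JaxRLExample
from project.trainers.jax_trainer import JaxTrainer

env, env_params = gymnax.make("CartPole-v1")          # a gymnax.Environment, as described above
algo = JaxRLExample(env=env, env_params=env_params)   # hypothetical constructor arguments
trainer = JaxTrainer(max_epochs=10)                    # hypothetical trainer options

rng = jax.random.PRNGKey(0)                            # Jax training loops are typically seeded explicitly
trainer.fit(algo, rng)                                  # assumed signature of JaxTrainer.fit
```
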
4 changes: 2 additions & 2 deletions docs/examples/jax_sl_example.md
@@ -9,12 +9,12 @@ In this example, the loss function and optimizers are in PyTorch, while the netw
The loss that is returned in the training step is used by Lightning in the usual way. The backward
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.

!!! note
!!! info
You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.



!!! note "What about end-to-end training in Jax?"
!!! question "What about end-to-end training in Jax?"

See the [Jax RL Example](../examples/jax_rl_example.md)! :smile:

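
The hunk above describes the hybrid setup: the forward and backward passes run in Jax, while `torch.autograd` and a regular PyTorch optimizer handle the parameter updates. Below is a generic sketch of one way such a bridge can be built with DLPack and `jax.vjp`; it is not the template's actual wrapper code, just an illustration of the mechanism.

```python
# Generic sketch (not the template's wrapper): run a function's forward and backward
# passes in Jax, while exposing it to torch.autograd as an ordinary differentiable op.
import jax
import jax.numpy as jnp
import torch
from torch.utils import dlpack as torch_dlpack


def to_jax(t: torch.Tensor):
    # Zero-copy handoff from torch to jax via DLPack.
    return jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(t.detach().contiguous()))


def to_torch(x) -> torch.Tensor:
    # Zero-copy handoff from jax back to torch via DLPack.
    return torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(x))


def wrap_jax_fn(jax_fn):
    """Wrap a single-input Jax function as a torch.autograd.Function."""

    class WrappedJaxFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
            y, vjp_fn = jax.vjp(jax_fn, to_jax(x))  # forward pass in Jax
            ctx.vjp_fn = vjp_fn                     # keep the pullback for the backward pass
            return to_torch(y)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            (grad_x,) = ctx.vjp_fn(to_jax(grad_output))  # backward pass in Jax
            return to_torch(grad_x)

    return WrappedJaxFn.apply


# Usage: gradients flow through the Jax function; a torch optimizer can then update weights.
softplus = wrap_jax_fn(lambda x: jnp.logaddexp(x, 0.0))
x = torch.randn(4, requires_grad=True)
softplus(x).sum().backward()
print(x.grad)
```
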
5 changes: 3 additions & 2 deletions docs/examples/sweeps.md
@@ -26,8 +26,9 @@ python project/main.py experiment=local_sweep_example
??? note "Click to show the yaml config file"
{{inline('project/configs/experiment/cluster_sweep_example.yaml', 4)}}

You can use it like so:
Here's how you can easily launch a sweep remotely on the Mila cluster.
If you are already on a slurm cluster, use the `"cluster=current"` config.

```console
python project/main.py experiment=cluster_sweep_example
python project/main.py experiment=cluster_sweep_example cluster=mila
```
17 changes: 17 additions & 0 deletions docs/features/index.md
@@ -0,0 +1,17 @@

# Features unique to this project template

Here are some cool features that are unique to this particular template:


- Support for both Jax and Torch with PyTorch-Lightning (See the [Jax example](jax.md))
- Your Hydra configs will have [auto-generated YAML schemas](auto_schema.md) 🔥
- A comprehensive suite of automated tests for new algorithms, datasets and networks
- 🤖 [Thoroughly tested directly on the Mila cluster with GitHub CI](testing.md#automated-testing-on-slurm-clusters-with-github-ci)
- Automated testing on the DRAC clusters will also be added soon.
- Easy development inside a [devcontainer with VsCode]
- Tailor-made for ML researchers that run their jobs on SLURM clusters (with default configurations for the [Mila](https://docs.mila.quebec) and [DRAC](https://docs.alliancecan.ca) clusters.)
- Rich typing of all parts of the source code

This template is aimed at ML researchers who run their jobs on SLURM clusters.
The target audience is researchers and students at [Mila](https://mila.quebec). This template should still be useful for others outside of Mila who use PyTorch-Lightning and Hydra.
3 changes: 2 additions & 1 deletion docs/features/jax.md
@@ -43,7 +43,8 @@ The [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer] will not be a

## End-to-end training in Jax: the `JaxTrainer`

The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure as the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s.
The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure as the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s, which are a simplified, jax-based look-alike of `lightning.LightningModule`s.


The "algorithm" needs to match the `JaxModule` protocol:
- `JaxModule.training_step`: train using a batch of data
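
As a rough mental model of "matching the `JaxModule` protocol", here is a minimal `typing.Protocol` sketch. Only `training_step` ("train using a batch of data") is named above; the extra `state` argument and the return type are assumptions, and the real `project.trainers.jax_trainer.JaxModule` defines more than this.

```python
# Minimal illustrative sketch; the real JaxModule has more methods and may use
# different signatures. Only `training_step` is taken from the docs above.
from typing import Any, Protocol


class JaxModuleLike(Protocol):
    def training_step(self, batch: Any, state: Any) -> tuple[Any, dict[str, Any]]:
        """Train using a batch of data (the `state` argument and return type are assumed)."""
        ...
```

Any object with a compatible `training_step` (plus whatever else the real protocol requires) can then be trained by the `JaxTrainer` described in the Jax RL example.
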
6 changes: 3 additions & 3 deletions docs/features/remote_slurm_launcher.md
@@ -33,18 +33,18 @@ This assumes that you've already setup SSH access to the clusters (for example u
python project/main.py experiment=example resources=gpu cluster=mila
```

### Local machine -> DRAC cluster (narval)
### Local machine -> DRAC (narval)

```bash
python project/main.py experiment=example resources=gpu cluster=narval
```


### Mila -> DRAC cluster (narval)
### Mila -> DRAC (narval)

This assumes that you've already setup SSH access from `mila` to the DRAC clusters.

Note that command is about the same as [above](#local-machine---drac-cluster-narval)
Note that the command is exactly the same as above.

```bash
python project/main.py experiment=example resources=gpu cluster=narval
24 changes: 20 additions & 4 deletions docs/generate_reference_docs.py
@@ -1,12 +1,14 @@
#!/usr/bin/env python
# based on https://github.com/mkdocstrings/mkdocstrings/blob/5802b1ef5ad9bf6077974f777bd55f32ce2bc219/docs/gen_doc_stubs.py#L25
"""Script used to generate the reference docs for the project from the source code.
Based on
https://github.com/mkdocstrings/mkdocstrings/blob/5802b1ef5ad9bf6077974f777bd55f32ce2bc219/docs/gen_doc_stubs.py#L25
"""

import os
import textwrap
from logging import getLogger as get_logger
from pathlib import Path

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
logger = get_logger(__name__)


@@ -38,7 +40,21 @@ def main():

with mkdocs_gen_files.open(full_doc_path, "w") as fd:
ident = ".".join(parts)
fd.write(f"::: {ident}\n")
fd.write(
textwrap.dedent(
# f"""\
# ---
# additional_python_references:
# - {ident}
# ---
# ::: {ident}
# """
f"""\
::: {ident}
"""
)
)
# fd.write(f"::: {ident}\n")

mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root))

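
For orientation, here is a condensed sketch of the `mkdocs-gen-files` pattern that `docs/generate_reference_docs.py` relies on: one stub page per module, each containing a `::: <module>` mkdocstrings directive. The package name and output folder below are illustrative, and the real script handles more cases than this.

```python
# Condensed sketch of the pattern used by docs/generate_reference_docs.py.
# It runs as a "gen-files" script during `mkdocs build`; paths are illustrative.
from pathlib import Path

import mkdocs_gen_files

package = Path("project")
for path in sorted(package.rglob("*.py")):
    parts = path.with_suffix("").parts          # e.g. ("project", "trainers", "jax_trainer")
    if parts[-1] == "__init__":
        parts = parts[:-1]                      # document the package itself, not __init__
    elif parts[-1].startswith("_"):
        continue                                # skip private modules

    doc_path = Path("reference", *parts).with_suffix(".md")
    with mkdocs_gen_files.open(doc_path, "w") as fd:
        fd.write(f"::: {'.'.join(parts)}\n")    # mkdocstrings renders this directive

    # Point the "edit this page" link at the source file instead of the generated stub.
    mkdocs_gen_files.set_edit_path(doc_path, path)
```
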
6 changes: 6 additions & 0 deletions docs/help.md
@@ -2,4 +2,10 @@

## FAQ



## How to get help

- Make an Issue on GitHub
- Reach out via Slack (if you're a researcher at Mila)
- Reach out via email
17 changes: 3 additions & 14 deletions docs/intro.md
@@ -13,21 +13,10 @@ Other good reads:
- [https://12factor.net/](https://12factor.net/)
- [https://github.com/ashleve/lightning-hydra-template/tree/main?tab=readme-ov-file#main-ideas](https://github.com/ashleve/lightning-hydra-template/tree/main?tab=readme-ov-file#main-ideas)

## Why should you use *this* template (instead of another)?
## Why use *this* template?

You are welcome (and encouraged) to use other similar templates which, at the time of writing this, have significantly better documentation. However, there are several advantages to using this particular template:
- [Cool, unique features that can *only* be found here (for now)!](features/index.md)

- ❗Support for both Jax and Torch with PyTorch-Lightning (See the [Jax example](features/jax.md))❗
- Your Hydra configs will have an [Auto-Generated YAML schemas](features/auto_schema.md) 🔥
- A comprehensive suite of automated tests for new algorithms, datasets and networks
- 🤖 [Thoroughly tested on the Mila directly with GitHub CI](features/testing.md#automated-testing-on-slurm-clusters-with-github-ci)
- Automated testing on the DRAC clusters will also be added soon.
- Easy development inside a devcontainer with VsCode
- Tailor-made for ML researchers that run their jobs on SLURM clusters (with default configurations for the [Mila](https://docs.mila.quebec) and [DRAC](https://docs.alliancecan.ca) clusters.)
- Rich typing of all parts of the source code

This template is aimed for ML researchers that run their jobs on SLURM clusters.
The target audience is researchers and students at [Mila](https://mila.quebec). This template should still be useful for others outside of Mila that use PyTorch-Lightning and Hydra.

## Project layout

@@ -38,7 +27,7 @@ project/
algorithms/ # learning algorithms
datamodules/ # datasets, processing and loading
networks/ # Neural networks used by algorithms
configs/ # configuration files
configs/ # Hydra configuration files
docs/ # documentation
conftest.py # Test fixtures and utilities
```
4 changes: 1 addition & 3 deletions docs/profiling_test.py
@@ -9,11 +9,9 @@
algorithm_config,
algorithm_network_config,
command_line_arguments,
command_line_overrides,
datamodule_config,
devices,
experiment_dictconfig,
num_devices_to_use,
overrides,
)
from project.experiment import setup_experiment
from project.utils.hydra_utils import resolve_dictconfig
21 changes: 0 additions & 21 deletions docs/related.md

This file was deleted.

33 changes: 33 additions & 0 deletions docs/resources.md
@@ -0,0 +1,33 @@
# Related projects and resources

## [Hydra docs](https://hydra.cc)


## Other project templates

There are other project templates out there that often have better documentation.
If you need an introduction to Hydra, Lightning, or good software development practices, these might have better guides and documentation for you.

Here are some we'd recommend:


### [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)

- How it works: https://github.com/gorodnitskiy/yet-another-lightning-hydra-template/tree/main?tab=readme-ov-file#workflow---how-it-works

For everything that has to do with Hydra and PyTorch-Lightning, their documentation also applies directly to this template. In order to avoid copying their documentation, we recommend you take a look at their nice readme.


### [yet-another-lightning-hydra-template](https://github.com/gorodnitskiy/yet-another-lightning-hydra-template)

- Excellent template, based on the lightning-hydra-template. Great documentation, which is referenced extensively in this project.
    - Has a **great** Readme with lots of information
    - Is really well organized
    - Doesn't support Jax
    - Doesn't have a devcontainer
- Great blog: https://hackernoon.com/yet-another-lightning-hydra-template-for-ml-experiments

### [cookiecutter-data-science](https://github.com/drivendataorg/cookiecutter-data-science)

- Awesome library for data science.
- Related projects: https://github.com/drivendataorg/cookiecutter-data-science/blob/master/docs/docs/related.md#links-to-related-projects-and-references