This repository provides a solid template for a PyTorch project: a working example of a deep learning project, together with reusable utilities along the way.
To run TensorBoard, simply run `tensorboard --logdir logs/tensorboard` from the project directory.
We use the Hydra framework for configuration. It allows us to easily read configuration files and to run hyper-parameter tuning. The configuration file is stored under `config/config.yaml` and has many parameters. At the beginning of each experiment, the configuration file is validated against a schema, which is stored at `utils/config_schema.py`.
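A minimal sketch of a Hydra entry point (the config keys and entry-point name used here are illustrative, not necessarily the ones in this repository):

```python
# Illustrative Hydra entry point; the actual config keys may differ.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    # cfg mirrors config/config.yaml, e.g. cfg.train.lr or cfg.train.num_epochs.
    print(cfg.train.lr)

if __name__ == "__main__":
    main()
```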
In case we want to add/remove a parameter from the configuration file, we need to:
- Change the YAML file
- Change the schema at `utils/config_schema.py`
- Verify that we don't break anything in `TrainParams`
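For example, adding a hypothetical `grad_clip` parameter would mean adding `grad_clip: 1.0` under the relevant section of `config/config.yaml` and extending the schema to match; a sketch, assuming a dataclass-style schema (the real `utils/config_schema.py` may use a different mechanism):

```python
# utils/config_schema.py -- illustrative sketch; the real schema mechanism may differ.
from dataclasses import dataclass

@dataclass
class TrainSchema:
    lr: float = 0.001
    num_epochs: int = 20
    grad_clip: float = 1.0  # hypothetical new parameter, mirrored in config.yaml
```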
To run hyper-parameter tuning, follow the instructions here.
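For reference, Hydra's built-in multi-run mode can sweep parameter values from the command line, e.g. `python main.py -m train.lr=0.001,0.01` (the entry-point name and parameter path here are illustrative).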
The `models` package should store all the models we use. For an example, take a look at `MyModel` in `models/base_model.py`.
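A minimal sketch of the expected structure (the class body below is illustrative, not the actual `MyModel`):

```python
# models/base_model.py -- illustrative sketch; not the repository's actual MyModel.
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Toy model showing the structure expected in the models package."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)
```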
The `nets` package should hold network building blocks, such as a modified sequential. For example, in the `nets` package you can see `FCNet`, a simple implementation of a fully connected network with weight normalization, dropout, and an easy way to add hidden layers with different numbers of hidden neurons.
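A sketch of how such a block might be put together (illustrative; see the actual `FCNet` in the `nets` package):

```python
# Illustrative fully connected block with weight normalization and dropout;
# not the repository's actual FCNet.
from typing import List

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

class FCNet(nn.Module):
    def __init__(self, dims: List[int], dropout: float = 0.2) -> None:
        # dims = [input, hidden_1, ..., output]; one weight-normalized
        # linear layer per consecutive pair, so hidden sizes are easy to vary.
        super().__init__()
        layers: List[nn.Module] = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.append(weight_norm(nn.Linear(d_in, d_out), dim=None))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x)
```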
The `train` function gathers all the training logic. Its structure is built for easy modification; for example, you can swap the optimizer. Moreover, during the training and evaluation stages, the logger reports the relevant metrics to stdout, a file, and TensorBoard.
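As a rough sketch of that structure (the signature and names are illustrative, not the repository's actual `train` function):

```python
# Illustrative training-loop skeleton; the real train() may differ.
import torch
import torch.nn as nn

def train(model: nn.Module, loader, num_epochs: int, lr: float) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Swapping the optimizer is a one-line change, e.g. torch.optim.SGD instead of Adam.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
```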
Creating a new dataset is built from three stages:
- Set variables: assign the relevant inputs to `self`.
- Load features: pick the best way to load the features (memory/disk) and implement the loading stage under `self._get_features()`.
- Create a list of entries: implement `self._get_entries()`.

Then, in `__getitem__`, you only need to retrieve samples from the list you created in stage #3, as in the sketch below.
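A skeleton of a dataset following these three stages (illustrative; the method bodies are placeholders):

```python
# Illustrative dataset skeleton following the three stages above;
# not an actual dataset from this repository.
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, features_path: str) -> None:
        # Stage 1: set variables.
        self.features_path = features_path
        # Stage 2: load features (from memory or disk).
        self.features = self._get_features()
        # Stage 3: create a list of entries.
        self.entries = self._get_entries()

    def _get_features(self):
        # Placeholder: load the features however fits your data.
        return []

    def _get_entries(self):
        # Placeholder: build one entry per sample, e.g. (feature, label) pairs.
        return [(f, 0) for f in self.features]

    def __getitem__(self, index):
        # Only retrieve the entry prepared in stage #3.
        return self.entries[index]

    def __len__(self):
        return len(self.entries)
```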
The logger class logs messages to:
- TensorBoard
- Stdout
- Files

Each experiment has a directory under `logs/` (configurable). By default, the best epoch and all the output will be saved there.
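A minimal sketch of the idea using standard APIs (the repository's own logger class may be structured differently):

```python
# Illustrative: one metric fanned out to stdout, a file, and TensorBoard.
import logging
from torch.utils.tensorboard import SummaryWriter

logger = logging.getLogger("experiment")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())                     # stdout
logger.addHandler(logging.FileHandler("logs/experiment.log"))  # file
writer = SummaryWriter(log_dir="logs/tensorboard")             # TensorBoard

logger.info("epoch 1: accuracy=0.91")
writer.add_scalar("accuracy", 0.91, global_step=1)
```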
In case you add or change a type, you can add it to `utils/types.py`.
To add a new dataset (following the three-stage structure sketched above):
- Add the relevant variables to the constructor
- Implement `self._get_features`
- Implement `self._get_entries`
- Implement `self.__getitem__`
To change a variable in the configuration file:
- Change it in the `config.yaml` file
- Update the schema under `utils/config_schema.py`
- Update `TrainParams` in `utils/train_utils.py`
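`TrainParams` presumably mirrors the training section of the configuration; a hypothetical sketch (the field names and the `from_cfg` helper are illustrative, not the actual `utils/train_utils.py`):

```python
# utils/train_utils.py -- illustrative TrainParams-style container;
# the field names and the from_cfg helper are hypothetical.
from dataclasses import dataclass

@dataclass
class TrainParams:
    lr: float
    num_epochs: int
    batch_size: int

    @classmethod
    def from_cfg(cls, cfg) -> "TrainParams":
        # cfg is the Hydra/OmegaConf config loaded from config/config.yaml.
        return cls(lr=cfg.train.lr,
                   num_epochs=cfg.train.num_epochs,
                   batch_size=cfg.train.batch_size)
```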
Credits:
- Hydra for a great configuration framework.
- bottom-up-attention-vqa for its great modified layers.