PyTorch implementation of the reversible residual network.
The main requirement is PyTorch. CUDA is strongly recommended.
The training script requires tqdm for the progress bar.
The unit tests require the `TestCase` class implemented by the PyTorch project. The module can be downloaded here.
The RevNet models in this project tend to suffer from exploding gradients. To counteract this, I use gradient norm clipping. To reproduce the experiments below, run:
```
python train_cifar.py --model revnet38 --clip 0.25
```
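The `--clip` value presumably sets the maximum gradient norm. A minimal sketch of one training step with gradient norm clipping, using PyTorch's `torch.nn.utils.clip_grad_norm_`, follows. The helper `train_step` and the toy linear model are hypothetical illustrations, not code taken from `train_cifar.py`.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, loss_fn, inputs, targets, clip=0.25):
    """One optimization step with gradient norm clipping (hypothetical helper)."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # Rescale all gradients in place so that their combined L2 norm is at
    # most `clip`. The returned value is the total norm *before* clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
    optimizer.step()
    return loss.item(), total_norm.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy stand-in for a RevNet: a single linear classifier.
    model = nn.Linear(10, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 10)
    y = torch.randint(0, 3, (8,))
    loss, norm = train_step(model, optimizer, nn.CrossEntropyLoss(), x, y, clip=0.25)
    print(f"loss={loss:.4f}, grad norm before clipping={norm:.4f}")
```

Clipping the global norm (rather than each gradient element) preserves the direction of the update and only shrinks its magnitude when it exceeds the threshold.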
| Model | Accuracy | Memory usage | Params |
|---|---|---|---|
| resnet32 | 92.02% | 1271 MB | 0.47 M |
| revnet38 | 91.98% | 660 MB | 0.47 M |
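The lower memory usage of revnet38 comes from the reversible blocks: each block's inputs can be reconstructed from its outputs, so intermediate activations do not have to be stored for backpropagation. Below is a minimal sketch of the additive coupling used in RevNets, where the input is split channel-wise into `x1, x2` and `y1 = x1 + F(x2)`, `y2 = x2 + G(y1)`. The choice of `F` and `G` (a single 3x3 convolution plus ReLU) is an assumption for illustration and not this project's architecture; the custom backward pass that actually discards and recomputes activations is omitted.

```python
import torch
import torch.nn as nn


class ReversibleBlock(nn.Module):
    """Additive coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""

    def __init__(self, channels):
        super().__init__()
        half = channels // 2
        # F and G are illustrative residual functions (assumed layers).
        self.f = nn.Sequential(nn.Conv2d(half, half, 3, padding=1), nn.ReLU())
        self.g = nn.Sequential(nn.Conv2d(half, half, 3, padding=1), nn.ReLU())

    def forward(self, x):
        # Split the channels into two halves.
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return torch.cat([y1, y2], dim=1)

    @torch.no_grad()
    def inverse(self, y):
        # Recover the inputs from the outputs by undoing the two
        # additive updates in reverse order.
        y1, y2 = torch.chunk(y, 2, dim=1)
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return torch.cat([x1, x2], dim=1)


if __name__ == "__main__":
    torch.manual_seed(0)
    block = ReversibleBlock(16)
    x = torch.randn(2, 16, 8, 8)
    y = block(x)
    print(torch.allclose(block.inverse(y), x, atol=1e-5))
```

Because `F` and `G` never need to be inverted themselves, they can be arbitrary functions; only the additive structure has to be reversible.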