tbung/pytorch-revnet
revnet

PyTorch implementation of the reversible residual network.
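The key property of a reversible residual block (the building block this repository implements) is that a block's inputs can be recomputed exactly from its outputs, so intermediate activations need not be stored during the forward pass. A minimal NumPy sketch of the coupling scheme, with toy stand-ins `f` and `g` for the residual sub-networks (the real model splits a tensor along the channel dimension into the two halves):

```python
import numpy as np

def f(x):
    # Toy stand-in for the residual function F (a conv block in the real model).
    return np.tanh(x)

def g(x):
    # Toy stand-in for the residual function G.
    return np.tanh(2.0 * x)

def rev_block_forward(x1, x2):
    """Forward pass of a reversible block:
    y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_block_inverse(y1, y2):
    """Reconstruct the inputs from the outputs by reversing the coupling:
    x2 = y2 - G(y1), x1 = y1 - F(x2)."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

# Two input halves, as produced by a channel-wise split.
x1, x2 = np.random.randn(4, 8), np.random.randn(4, 8)
y1, y2 = rev_block_forward(x1, x2)
r1, r2 = rev_block_inverse(y1, y2)
assert np.allclose(x1, r1) and np.allclose(x2, r2)
```

Because the inverse is exact, the backward pass can reconstruct activations on the fly instead of caching them, which is where the memory savings in the results below come from.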

Requirements

The main requirement is, obviously, PyTorch. CUDA is strongly recommended.

The training script requires tqdm for the progress bar.

The unit tests require the TestCase class implemented by the PyTorch project; the module can be downloaded from the PyTorch repository.

Note

The revnet models in this project tend to have exploding gradients. To counteract this, I used gradient norm clipping. To reproduce the experiments below, run:

python train_cifar.py --model revnet38 --clip 0.25
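Gradient norm clipping rescales all gradients whenever their combined L2 norm exceeds a threshold (here 0.25); in PyTorch this is typically done with `torch.nn.utils.clip_grad_norm_`. A NumPy sketch of the operation, to show what `--clip 0.25` does:

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Sketch of gradient norm clipping: if the total L2 norm of all
    gradients exceeds max_norm, scale every gradient by
    max_norm / total_norm so the clipped norm equals max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# One parameter's gradient with L2 norm 5.0, clipped to 0.25.
grads = [np.array([3.0, 4.0])]
clipped, norm = clip_grad_norm(grads, 0.25)
```

After clipping, the direction of the update is unchanged; only its magnitude is capped, which keeps the occasional exploding gradient from derailing training.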

Results

CIFAR-10

Model     Accuracy  Memory Usage  Params
resnet32  92.02%    1271 MB       0.47 M
revnet38  91.98%    660 MB        0.47 M
