This repository contains the PyTorch implementation of SAViT. For details, see the paper SAViT: Structure-Aware Vision Transformer Pruning via Collaborative Optimization.
SAViT is a structured pruning method that jointly prunes the components of a vision transformer (such as attention heads and MLP neurons) and reaches a 2.05x speedup with only 0.2% accuracy loss.
Step 1: Create a new conda environment:
conda create -n savit python=3.8
conda activate savit
Step 2: Install relevant packages
cd /path to deit_savit
pip install -r requirements.txt
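To verify the environment, a quick import check like the following can be run. This is only a minimal sketch; it assumes requirements.txt installs torch, torchvision, and timm, which the training code depends on.

import torch
import torchvision
import timm

# print the installed versions and whether a GPU is visible
print("torch", torch.__version__)
print("torchvision", torchvision.__version__)
print("timm", timm.__version__)
print("CUDA available:", torch.cuda.is_available())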
Download and extract ImageNet train and val images from http://image-net.org/.
The directory structure is the standard layout for torchvision's datasets.ImageFolder, and the training and validation data are expected to be in the train/ and val/ folders respectively:
/path to imagenet
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
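As a sanity check, the layout above is exactly what torchvision's datasets.ImageFolder expects. The following is a minimal sketch; the transforms are only illustrative, since main.py builds its own training pipeline.

from torchvision import datasets, transforms

# simple evaluation-style preprocessing, just to confirm the folder layout loads
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder("/path to imagenet/train", transform=transform)
val_set = datasets.ImageFolder("/path to imagenet/val", transform=transform)
print(len(train_set.classes), "classes,", len(train_set), "train images,", len(val_set), "val images")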
The scripts folder contains all the bash commands to replicate the main results in our paper:
Running the following command to prune DeiT-base will give you a pruned model with a mask that indicates which neurons or heads should be removed, corresponding to the results in Table 2.
Prune DeiT-base by 50% FLOPs
python main.py \
--finetune=/path to deit_base checkpoint \
--batch-size=32 \
--num_workers=16 \
--data-path=/path to ImageNet \
--model=deit_base_patch16_224 \
--pruning_per_iteration=100 \
--pruning_feed_percent=0.1 \
--pruning_method=2 \
--pruning_layers=3 \
--pruning_flops_percentage=0.50 \
--pruning_flops_threshold=0.0001 \
--need_hessian \
--finetune_op=2 \
--epochs=1 \
--output_dir=/path to output
You can change the FLOPs reduction or the model as you wish. If you have already computed the pruning importance metrics, you can simply load them by setting:
--pruning_pickle_from=/path to importance
For help information on the arguments, please see main.py.
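To get a quick look at what the pruning step produced, the saved checkpoint can be opened directly. This is only a sketch: the dictionary layout and the mask key names below are assumptions and may differ from the actual checkpoint format.

import torch

ckpt = torch.load("/path to output/checkpoint.pth", map_location="cpu")
print(ckpt.keys())  # locate the state dict and any mask entries

# many DeiT-style checkpoints store the weights under a "model" key (assumption)
state = ckpt.get("model", ckpt)
for name, tensor in state.items():
    # "mask" in the key name is a hypothetical convention; adapt it to the real keys
    if isinstance(tensor, torch.Tensor) and "mask" in name:
        kept = int(tensor.sum().item())
        print(f"{name}: keeping {kept} of {tensor.numel()} units")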
After pruning DeiT-base, we need to retrain the pruned model to recover its performance. Run the following command to fine-tune on ImageNet on a single node with 8 GPUs and a total batch size of 1024 (128 per GPU) for 300 epochs.
Fine-tune pruned DeiT-base
GPU_NUM=8
output_dir=/path to output
ck_dir=$output_dir/checkpoint.pth
# check if a checkpoint exists and resume from it
if [ -e "$ck_dir" ]; then
    CMD="--resume=${ck_dir}"
else
    CMD="--resume="
fi
python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} --use_env main_deploy.py \
--dist-eval \
$CMD \
--masked_model=/path to pruned_model in previous step prune \
--teacher-path=/path to regnet teacher model as in the DeiT paper \
--batch-size=128 \
--num_workers=16 \
--data-path=/path to ImageNet \
--model=deit_base_patch16_224_deploy \
--pruning_flops_percentage=0 \
--finetune_op=1 \
--epochs=300 \
--warmup-epochs=0 \
--cooldown-epochs=0 \
--output_dir=$output_dir
Note: fine-tuning runs main_deploy.py, which generates a smaller model according to the pruning mask in the pruned model from the previous prune step, in order to accelerate fine-tuning.
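If you want to double-check the FLOPs reduction independently of the repository's own counter (the flops-counter project listed in the acknowledgements below), a rough sketch with the ptflops package looks like the following. It is shown on the unpruned timm DeiT-base only, since the pruned deploy model is constructed inside main_deploy.py from the mask; numbers may differ slightly from the repository's counter.

import timm
from ptflops import get_model_complexity_info

# build the unpruned baseline just to illustrate the measurement
model = timm.create_model("deit_base_patch16_224", pretrained=False)
macs, params = get_model_complexity_info(
    model, (3, 224, 224), as_strings=True, print_per_layer_stat=False
)
print(f"MACs: {macs}, params: {params}")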
To ease reproduction of our results, we provide the prune and fine-tune logs in the logs folder. The slight difference between the results in the logs and the results in our paper comes from the PyTorch version.
Our repository is built on DeiT, Taylor_pruning, Timm, and flops-counter; we sincerely thank the authors for their nicely organized code!
This repository is released under the Apache 2.0 license as found in the LICENSE file.
If you find this repository helpful, please cite:
@article{zheng2022savit,
title={SAViT: Structure-Aware Vision Transformer Pruning via Collaborative Optimization},
author={Zheng, Chuanyang and Zhang, Kai and Yang, Zhi and Tan, Wenming and Xiao, Jun and Ren, Ye and Pu, Shiliang and others},
journal={Advances in Neural Information Processing Systems},
volume={35},
pages={9010--9023},
year={2022}
}