PyTorch implementation of [Neural Architecture Search for Spiking Neural Networks], ECCV 2022 (https://arxiv.org/abs/2201.10355)
Spiking Neural Networks (SNNs) have gained considerable attention as a potentially energy-efficient alternative to conventional Artificial Neural Networks (ANNs) due to their inherently sparse activations. However, most prior SNN methods use ANN-like architectures (e.g., VGG-Net or ResNet), which may provide sub-optimal performance for the temporal sequence processing of binary information in SNNs. To address this, we introduce a novel Neural Architecture Search (NAS) approach for finding better SNN architectures. Inspired by recent NAS approaches that find the optimal architecture from activation patterns at initialization, we select the architecture that can represent diverse spike activation patterns across different data samples without training. Moreover, to further leverage the temporal information among the spikes, we search for feed-forward connections as well as backward connections (i.e., temporal feedback connections) between layers. Interestingly, SNASNet, the architecture found by our search algorithm, achieves higher performance with backward connections, demonstrating the importance of designing SNN architectures that make suitable use of temporal information. We conduct extensive experiments on three image recognition benchmarks, where SNASNet achieves state-of-the-art performance with significantly fewer timesteps (5 timesteps).
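The training-free selection described above ranks an untrained network by how differently its binary activations respond to different inputs. Below is a minimal sketch of such a score in the style of the nas-without-training code credited in this README: build the kernel K[i, j] = N_A - d_H(c_i, c_j) from the binary codes c_i of a mini-batch (N_A is the code length, d_H the Hamming distance) and take log|det K|. It assumes each sample's code is a flattened 0/1 spike vector (for example, spikes concatenated over layers and timesteps). The function names are hypothetical, and the SNASNet code may compute its score differently.

```python
import numpy as np


def hamming_kernel(codes: np.ndarray) -> np.ndarray:
    """K[i, j] = N_A - HammingDistance(codes[i], codes[j]) for 0/1 codes of shape (N, N_A)."""
    c = codes.astype(np.float64)
    # Positions where both samples spike plus positions where both stay silent;
    # this count equals N_A minus the Hamming distance between the two codes.
    return c @ c.T + (1.0 - c) @ (1.0 - c).T


def activation_diversity_score(codes: np.ndarray) -> float:
    """log|det K|: higher means the batch is mapped to more distinguishable spike patterns."""
    _sign, logdet = np.linalg.slogdet(hamming_kernel(codes))
    # Can be -inf when K is singular, e.g. when two samples produce identical codes.
    return float(logdet)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical batch: 8 samples, 64 spiking units, roughly 20% of units firing.
    codes = (rng.random((8, 64)) < 0.2).astype(np.int8)
    print(activation_diversity_score(codes))
```

Computing the kernel with two matrix products avoids an explicit pairwise loop, and `slogdet` is used instead of `det` so that large batches do not overflow.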
- Python 3.9
- PyTorch 1.10.0
- NVIDIA GPU (>= 12 GB memory)
- CUDA 10.2 (optional)
conda create -n SNASNet python=3.9
conda activate SNASNet
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=10.2 -c pytorch
pip install scipy
SpikingJelly installation (ref: https://github.com/fangwei123456/spikingjelly)
git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
- Arguments required for training and testing are defined in `config.py`.
- Here is an example of running an experiment on CIFAR-100.
- (Optional) To skip the search process and use a predefined architecture, specify its connection matrix with `--cnt_mat` in the following format:
--cnt_mat 0302 0030 3003 0000
Example: an architecture and its corresponding connection matrix
- Run the following command:
python search_snn.py --exp_name 'cifar100_backward' --dataset 'cifar100' --celltype 'backward' --batch_size 32 --num_search 5000
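The `--cnt_mat` value is a square connection matrix written one digit string per row. A minimal parsing sketch follows; it is not the repository's parser. It assumes (none of this is confirmed by this README) that the entry at row i, column j describes the edge from node i to node j, that 0 means no connection, that nonzero digits are operation codes whose meaning is defined in the SNASNet code, and that entries below the diagonal are the backward (temporal feedback) connections. `parse_cnt_mat` and `split_edges` are hypothetical names.

```python
from typing import List, Tuple

Edge = Tuple[int, int, int]  # (source node, target node, operation code)


def parse_cnt_mat(rows: List[str]) -> List[List[int]]:
    """Turn tokens like ["0302", "0030", "3003", "0000"] into a square int matrix."""
    n = len(rows)
    if any(len(r) != n or not r.isdigit() for r in rows):
        raise ValueError(f"expected {n} digit strings of length {n}, got {rows}")
    return [[int(ch) for ch in r] for r in rows]


def split_edges(mat: List[List[int]]) -> Tuple[List[Edge], List[Edge]]:
    """Separate forward (above-diagonal) and backward (below-diagonal) edges (assumed layout)."""
    forward: List[Edge] = []
    backward: List[Edge] = []
    for i, row in enumerate(mat):
        for j, op in enumerate(row):
            if op == 0:
                continue  # assumed: 0 means "no connection"
            if i == j:
                raise ValueError(f"unexpected self-connection at node {i}")
            (forward if i < j else backward).append((i, j, op))
    return forward, backward


if __name__ == "__main__":
    fwd, bwd = split_edges(parse_cnt_mat("0302 0030 3003 0000".split()))
    print("forward:", fwd)   # forward: [(0, 1, 3), (0, 3, 2), (1, 2, 3), (2, 3, 3)]
    print("backward:", bwd)  # backward: [(2, 0, 3)]
```

Under this reading, both pretrained architectures listed below contain below-diagonal entries, which is consistent with their being searched with `--celltype 'backward'`.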
Brief description of the arguments:
--exp_name: name of the save file
--dataset: dataset used for the experiment
--celltype: whether to search for backward connections or forward connections
--batch_size: batch size
--num_search: number of candidate architectures evaluated during the search
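`--num_search` sets how many candidate architectures are scored before one is selected. The loop below is a generic random-search sketch of that idea, not the repository's `search_snn.py`: the sampler, the 4-node cell, the operation codes (1, 2, 3), the `backward` flag standing in for `--celltype`, and the toy scoring function are all assumptions for illustration. In a real search, `score_fn` would build the candidate SNN and score its spike patterns at initialization.

```python
import random
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[int]]


def sample_cnt_mat(num_nodes: int, op_codes: Sequence[int], backward: bool,
                   rng: random.Random) -> Matrix:
    """Hypothetical sampler: each allowed edge gets 0 (no edge) or a random op code."""
    mat = [[0] * num_nodes for _ in range(num_nodes)]
    for i in range(num_nodes):
        for j in range(num_nodes):
            # Above the diagonal: forward edges; below it: backward edges (assumed layout).
            if i < j or (backward and i > j):
                mat[i][j] = rng.choice([0, *op_codes])
    return mat


def random_search(score_fn: Callable[[Matrix], float], num_search: int,
                  num_nodes: int = 4, op_codes: Sequence[int] = (1, 2, 3),
                  backward: bool = True, seed: int = 0) -> Tuple[Matrix, float]:
    """Score num_search random candidates without training and keep the best one."""
    rng = random.Random(seed)
    best_mat: Matrix = []
    best_score = float("-inf")
    for _ in range(num_search):
        mat = sample_cnt_mat(num_nodes, op_codes, backward, rng)
        score = score_fn(mat)
        if not best_mat or score > best_score:
            best_mat, best_score = mat, score
    return best_mat, best_score


def to_cnt_mat_arg(mat: Matrix) -> str:
    """Format a matrix the way the --cnt_mat examples in this README are written."""
    return "--cnt_mat " + " ".join("".join(str(v) for v in row) for row in mat)


if __name__ == "__main__":
    # Toy stand-in score (number of edges), used only to make the sketch runnable.
    def count_edges(m: Matrix) -> float:
        return float(sum(v != 0 for row in m for v in row))

    best, score = random_search(count_edges, num_search=100)
    print(to_cnt_mat_arg(best), score)
```

Because every candidate is scored without training, the cost of the search grows linearly with `num_search` and is dominated by one forward pass per candidate.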
(1) Download the pretrained parameters:
CIFAR-10: (link), save to ./savemodel/save_cifar10_bw.pth.tar
CIFAR-100: (link), save to ./savemodel/save_cifar100_bw.pth.tar
(2) The pretrained models above correspond to the following architectures:
CIFAR-10: --cnt_mat 0303 0030 2002 0200
CIFAR-100: --cnt_mat 0302 0030 3003 0000
(3) Run the following commands:
python search_snn.py --dataset 'cifar10' --cnt_mat 0303 0030 2002 0200 --savemodel_pth './savemodel/save_cifar10_bw.pth.tar' --celltype 'backward' --second_avgpooling 4
python search_snn.py --dataset 'cifar100' --cnt_mat 0302 0030 3003 0000 --savemodel_pth './savemodel/save_cifar100_bw.pth.tar' --celltype 'backward'
The Hamming distance measurement code is adapted from: https://github.com/BayesWatch/nas-without-training
SpikingJelly reference: https://github.com/fangwei123456/spikingjelly
@article{kim2022neural,
title={Neural architecture search for spiking neural networks},
author={Kim, Youngeun and Li, Yuhang and Park, Hyoungseob and Venkatesha, Yeshwanth and Panda, Priyadarshini},
journal={arXiv preprint arXiv:2201.10355},
year={2022}
}