Warning
SPSN has evolved into the ParaLIF (Parallelizable Leaky-Integrate-and-Fire) neuron. ParaLIF supports additional stochastic and deterministic spiking functions, and a recurrent version is also available. Visit https://github.com/NECOTIS/Parallelizable-Leaky-Integrate-and-Fire-Neuron
This repository contains code for simulating the proposed Stochastic Parallelizable Spiking Neuron (SPSN), which accelerates the training of spiking neural networks (SNNs). The SPSN is compared to the Leaky Integrate-and-Fire (LIF) neuron on the Spiking Heidelberg Digits (SHD) dataset. The repository consists of the following key components:
- `datasets.py`: This module provides a simple interface for loading and accessing the training and test datasets.
- `network.py`: This module contains the implementation of the neural network itself, including code for training and evaluating the network.
- `run.py`: This is the main entry point for running the simulation. It provides a simple command-line interface for specifying various options.
- `datasets` directory: This directory contains the training and test datasets as hdf5 files. The SHD dataset must be downloaded into this directory from https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/
- `neurons` directory: This directory contains implementations of the two neuron types, each extending the base class in `base.py`. The available models are:
  - `lif.py`: The Leaky Integrate-and-Fire model.
  - `spsn.py`: The Stochastic Parallelizable Spiking Neuron model. It can be simulated with the Sigmoid-Bernoulli firing mode (SPSN-SB) or with the Gumbel Softmax firing mode (SPSN-GS).
- `outputs` directory: This directory contains the outputs generated by the simulation.
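The difference between `lif.py` and `spsn.py` can be illustrated with a minimal NumPy sketch. It assumes, as the neuron's name suggests, that SPSN drops the membrane reset: without a reset, leaky integration is a linear filter of the input and can be computed for all time steps at once, after which spikes are drawn independently at every step. The exponential discretization in `decay_factors`, the threshold `thr`, and the Gumbel-softmax logits are illustrative assumptions, and every function name is hypothetical; this is not the code in `neurons/spsn.py`, which uses PyTorch and gradients.

```python
import numpy as np

def decay_factors(tau_mem, tau_syn, dt=1.0):
    """Per-step decay factors, assuming exponential decay discretized with step dt."""
    alpha = np.exp(-dt / tau_syn)  # synaptic current decay
    beta = np.exp(-dt / tau_mem)   # membrane potential leak
    return alpha, beta

def lif_sequential(inp, alpha, beta, thr=1.0, reset=True):
    """Step-by-step LIF simulation. inp: (time, neurons). Returns (membrane, spikes)."""
    T, N = inp.shape
    syn, mem = np.zeros(N), np.zeros(N)
    mems, spikes = np.zeros((T, N)), np.zeros((T, N))
    for t in range(T):
        syn = alpha * syn + inp[t]
        mem = beta * mem + syn
        mems[t] = mem
        spikes[t] = mem > thr
        if reset:
            # The reset makes step t+1 depend on the spike at step t,
            # which forces a sequential loop over time.
            mem = mem * (1.0 - spikes[t])
    return mems, spikes

def decay_kernel(decay, T):
    """Lower-triangular matrix K with K[t, k] = decay**(t - k) for k <= t."""
    lag = np.arange(T)[:, None] - np.arange(T)[None, :]
    return np.tril(decay ** np.clip(lag, 0, None))

def membrane_parallel(inp, alpha, beta):
    """Reset-free membrane potential for all time steps at once (two matrix products)."""
    T = inp.shape[0]
    syn = decay_kernel(alpha, T) @ inp
    return decay_kernel(beta, T) @ syn

def fire_sigmoid_bernoulli(mem, rng, thr=1.0):
    """Spike with probability sigmoid(mem - thr), independently at every step."""
    p = 1.0 / (1.0 + np.exp(-(mem - thr)))
    return (rng.random(mem.shape) < p).astype(float)

def fire_gumbel_softmax(mem, rng, thr=1.0, temperature=1.0, eps=1e-10):
    """Binary Gumbel-softmax with logits (mem - thr) for 'spike' and 0 for 'no spike'."""
    g_spike = -np.log(-np.log(rng.uniform(eps, 1.0, mem.shape)))
    g_silent = -np.log(-np.log(rng.uniform(eps, 1.0, mem.shape)))
    soft = 1.0 / (1.0 + np.exp(-((mem - thr) + g_spike - g_silent) / temperature))
    return (soft > 0.5).astype(float)  # hard sample; training would backprop through `soft`
```

With `reset=False`, `lif_sequential` and `membrane_parallel` produce the same membrane trace, which is why the reset-free neuron can replace the time loop with batched matrix products. The dense `(T, T)` kernel is the simplest parallel form; an FFT-based convolution would scale better for long sequences.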
The `run.py` script accepts the following command-line arguments:
- `--seed`: Random seed for reproducibility.
- `--dataset`: The dataset used for training; currently only `heidelberg` is supported.
- `--neuron`: The neuron model used for training; options are `LIF`, `SPSN-SB`, `SPSN-GS`, and `Non-Spiking`. The `Non-Spiking` neuron is a traditional neuron followed by a ReLU activation.
- `--nb_epochs`: The number of training epochs.
- `--tau_mem`: The neuron membrane time constant.
- `--tau_syn`: The neuron synaptic current time constant.
- `--batch_size`: The batch size for training.
- `--hidden_size`: The number of neurons in each hidden layer.
- `--nb_layers`: The number of hidden layers.
- `--reg_thr`: The spiking frequency regularization threshold.
- `--loss_mode`: The mode for computing the loss; options are `last`, `max`, and `mean`.
- `--data_augmentation`: Whether to use data augmentation during training; options are `True` and `False`.
- `--h_shift`: The random shift factor for data augmentation.
- `--scale`: The random scale factor for data augmentation.
- `--dir`: The directory in which to save the results.
- `--save_model`: Whether to save the trained model; options are `True` and `False`.
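`--loss_mode` chooses how the output activity over time is collapsed into class scores before the loss is computed. The sketch below assumes the output layer produces a readout trace of shape `(batch, time, classes)` (for example the membrane potential of non-spiking output units) and that a cross-entropy loss follows; `reduce_over_time` and `cross_entropy` are hypothetical names, and `network.py` may implement this differently.

```python
import numpy as np

def reduce_over_time(out, loss_mode):
    """Collapse a (batch, time, classes) readout into (batch, classes) logits."""
    if loss_mode == "last":
        return out[:, -1, :]      # value at the final time step
    if loss_mode == "max":
        return out.max(axis=1)    # peak value over time
    if loss_mode == "mean":
        return out.mean(axis=1)   # average value over time
    raise ValueError(f"unknown loss_mode: {loss_mode!r}")

def cross_entropy(logits, labels):
    """Mean negative log-likelihood using a numerically stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```

For a single sample with three time steps and two classes, `out = np.array([[[0.0, 1.0], [2.0, 0.5], [0.0, 0.0]]])`, the `last` mode gives `[[0.0, 0.0]]`, `max` gives `[[2.0, 1.0]]`, and `mean` gives `[[0.667, 0.5]]` (rounded).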
To run the code in the basic mode, the following commands can be used:

```shell
python run.py --seed 0 --neuron 'LIF'
python run.py --seed 0 --neuron 'SPSN-SB'
python run.py --seed 0 --neuron 'SPSN-GS'
python run.py --seed 0 --neuron 'Non-Spiking'
```
To add data augmentation during training, the following commands can be used:

```shell
python run.py --seed 0 --neuron 'LIF' --data_augmentation True
python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True
python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True
python run.py --seed 0 --neuron 'Non-Spiking' --data_augmentation True
```
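One plausible reading of `--h_shift` and `--scale` is a random affine remapping of the input-channel index of each SHD spike event. The sketch below makes that assumption explicit: the scale is drawn from `[1 - scale, 1 + scale]`, the shift is a fraction `h_shift` of the 700 input channels, and remapped events that fall outside the channel range are dropped. `augment_units` is a hypothetical helper, not the repository's implementation.

```python
import numpy as np

def augment_units(units, rng, h_shift, scale, nb_units=700):
    """Randomly scale and shift the channel index of each spike event.

    Assumed semantics (hypothetical): one scale factor and one shift are
    drawn per sample; events mapped outside [0, nb_units) are discarded.
    Returns the new unit indices and the boolean mask of kept events,
    so the caller can filter the matching spike times as well.
    """
    s = rng.uniform(1.0 - scale, 1.0 + scale)        # random scale factor
    d = rng.uniform(-h_shift, h_shift) * nb_units    # random shift, in channels
    new_units = np.round(np.asarray(units) * s + d).astype(int)
    keep = (new_units >= 0) & (new_units < nb_units)
    return new_units[keep], keep
```

Drawing one scale and shift per sample (rather than per event) keeps the relative frequency structure of a spoken digit intact while moving it across channels.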
To reduce the spiking frequency of SPSN, regularization can be enabled with the following commands:

```shell
python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True --reg_thr 0.4
python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True --reg_thr 0.1
```
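A common way to regularize spiking frequency with a threshold is to penalize only the neurons whose firing rate exceeds it. The sketch below assumes the rate is measured in spikes per time step (so `reg_thr` lies between 0 and 1, consistent with the values 0.4 and 0.1 above) and uses a squared hinge penalty; both the penalty form and the name `firing_rate_penalty` are assumptions rather than the repository's exact loss.

```python
import numpy as np

def firing_rate_penalty(spikes, reg_thr):
    """Squared-hinge penalty on neurons whose mean spike rate exceeds reg_thr.

    spikes: 0/1 array of shape (batch, time, neurons).
    Rates are spikes per time step, so reg_thr lies in [0, 1].
    """
    rate = spikes.mean(axis=(0, 1))            # per-neuron firing rate
    excess = np.maximum(rate - reg_thr, 0.0)   # neurons below the threshold cost nothing
    return np.mean(excess ** 2)
```

In training, such a term would be added to the classification loss, e.g. `loss = task_loss + reg_weight * firing_rate_penalty(spikes, reg_thr)`, where `reg_weight` is a hypothetical weighting coefficient.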
The results obtained with the commands listed above are summarized in the following tables:

- Test-set classification accuracy:

| Neuron | Basic | Data augmentation | Data augmentation + Regularization |
|---|---|---|---|
| LIF | 71.37% | 83.03% | - |
| SPSN-SB | 77.16% | 86.08% | 89.70% |
| SPSN-GS | 75.66% | 86.08% | 89.39% |
| Non-Spiking | 71.82% | 66.07% | - |

- Training duration per epoch:

| Neuron | Basic | Data augmentation | Data augmentation + Regularization |
|---|---|---|---|
| LIF | 252.2 s | 261.7 s | - |
| SPSN-SB | 5.5 s | 10.7 s | 10.7 s |
| SPSN-GS | 6.8 s | 12.2 s | 12.2 s |
| Non-Spiking | 1.5 s | 6.7 s | - |
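The per-epoch speedup of SPSN over LIF follows directly from the timing table. The snippet below only divides the table values; it makes no assumption about the hardware used.

```python
# Per-epoch training times in seconds, copied from the table above.
times = {
    "Basic": {"LIF": 252.2, "SPSN-SB": 5.5, "SPSN-GS": 6.8},
    "Data augmentation": {"LIF": 261.7, "SPSN-SB": 10.7, "SPSN-GS": 12.2},
}

def speedups(times):
    """Ratio of LIF epoch time to each SPSN variant's epoch time, per setting."""
    return {
        setting: {n: round(t["LIF"] / t[n], 1) for n in ("SPSN-SB", "SPSN-GS")}
        for setting, t in times.items()
    }

for setting, ratio in speedups(times).items():
    print(setting, ratio)
# Basic {'SPSN-SB': 45.9, 'SPSN-GS': 37.1}
# Data augmentation {'SPSN-SB': 24.5, 'SPSN-GS': 21.5}
```

SPSN-SB trains roughly 46 times faster per epoch than LIF in the basic setting and roughly 24 times faster with data augmentation.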
The following libraries are required to run the code:
- h5py
- numpy
- torch
- torchvision
- tqdm
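`h5py` is used to read the SHD hdf5 files, which store each sample as variable-length arrays of spike times (in seconds) and input-unit indices (0 to 699) under `spikes/times` and `spikes/units`, with class labels under `labels`. A minimal sketch of turning those events into dense time-binned inputs follows; the number of bins `nb_steps`, the cut-off `max_time`, and both function names are hypothetical choices, and `datasets.py` may preprocess the data differently.

```python
import numpy as np

def bin_spikes(times, units, nb_steps, max_time, nb_units=700):
    """Convert one sample's spike events into a dense (nb_steps, nb_units) count array."""
    steps = np.floor(np.asarray(times) / max_time * nb_steps).astype(int)
    units = np.asarray(units, dtype=int)
    keep = steps < nb_steps                            # drop events after max_time
    dense = np.zeros((nb_steps, nb_units))
    np.add.at(dense, (steps[keep], units[keep]), 1.0)  # accumulate repeated (step, unit) pairs
    return dense

def load_shd(path, nb_steps=100, max_time=1.4):
    """Load an SHD hdf5 file into dense inputs and integer labels."""
    import h5py  # imported here so bin_spikes can be used without h5py installed
    with h5py.File(path, "r") as f:
        times, units = f["spikes"]["times"], f["spikes"]["units"]
        labels = np.asarray(f["labels"], dtype=int)
        x = np.stack([bin_spikes(times[i], units[i], nb_steps, max_time)
                      for i in range(len(labels))])
    return x, labels
```

`np.add.at` is used instead of plain fancy-index assignment so that several spikes falling into the same time bin and unit are counted rather than overwritten.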