# Using a GAN to Learn 'Good Pruning Masks'

## Part 1 - Create a pipeline to generate masks using standard methods and retain the good ones

Randomise the existing pruning techniques slightly to generate a variety of masks for the same initialization and pruning ratio.
Colab - https://colab.research.google.com/drive/18u5n-ykZJrDqf1vYiPILnpwFSFf2gzjO?usp=sharing.
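
The retention loop described above could be organized roughly as in the sketch below. This is only an illustration of the idea, not code from the notebook: `randomized_prune`, `retrain`, and `evaluate` are hypothetical helpers standing in for a randomized pruning criterion, a short fine-tuning pass, and test-set evaluation.

```python
import copy

def collect_good_masks(model, sparsity, n_candidates=20, acc_threshold=0.98):
    """Generate randomized pruning masks and keep the ones that retrain well."""
    good_masks = []
    for _ in range(n_candidates):
        candidate = copy.deepcopy(model)              # start from the trained model
        mask = randomized_prune(candidate, sparsity)  # hypothetical randomized criterion
        retrain(candidate, epochs=3)                  # hypothetical short fine-tuning
        if evaluate(candidate) >= acc_threshold:      # hypothetical test-set accuracy
            good_masks.append(mask)                   # retain only the good masks
    return good_masks
```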

As a starting point, I have selected the LeNet5 model and the MNIST dataset for the experiments.
The following pruning techniques were explored:

  1. Global Channel Pruning - based on L1 norm of weights in each channel
  2. Global Channel Pruning - selecting channels at random

## Experiments

First, LeNet5 is trained on the MNIST dataset for a few epochs to obtain the best-performing model (accuracy = 0.989).
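
A minimal training sketch for this step, assuming a standard LeNet5 variant for 28x28 MNIST inputs; the exact architecture and hyperparameters in the Colab notebook may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)   # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(6, 16, 5)             # 14x14 -> 10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def train_epoch(model, loader, optimizer, device="cpu"):
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), labels)
        loss.backward()
        optimizer.step()

train_set = datasets.MNIST("data", train=True, download=True,
                           transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
model = LeNet5()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(5):            # a few epochs
    train_epoch(model, train_loader, optimizer)
```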

### Global Channel Pruning - based on L1 norm of weights in each channel

The criterion for pruning a channel is (sum of the L1 norms of its weights) / (total number of weights in the channel), i.e. the mean absolute weight of the channel; channels with the smallest values are pruned first.
Varying the sparsity from 0 to 95%, the best model is pruned according to this criterion and retrained on the training set for a few epochs. The best-model accuracy, pruned accuracy, and retrained accuracy at every level of sparsity are computed on the test set. A minimal sketch of this scoring and masking step is given below, followed by a figure summarizing the observations.
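
A sketch of how this criterion could be computed and applied, assuming "sparsity" here means the fraction of conv channels pruned and that pruning is realized by zeroing the selected output channels (the notebook may instead remove channels structurally or use `torch.nn.utils.prune`); only `Conv2d` layers are handled.

```python
import torch
import torch.nn as nn

def l1_channel_scores(model):
    """(layer name, channel index, mean |w|) for every conv output channel."""
    scores = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            w = m.weight.detach()                               # [out, in, k, k]
            per_channel = w.abs().sum(dim=(1, 2, 3)) / w[0].numel()
            scores += [(name, c, s) for c, s in enumerate(per_channel.tolist())]
    return scores

def zero_channels(model, to_prune):
    """Zero out the weights (and bias) of the selected (layer, channel) pairs."""
    with torch.no_grad():
        for name, m in model.named_modules():
            if isinstance(m, nn.Conv2d):
                for c in range(m.out_channels):
                    if (name, c) in to_prune:
                        m.weight[c].zero_()
                        if m.bias is not None:
                            m.bias[c] = 0.0

def global_l1_channel_prune(model, sparsity):
    """Prune the globally lowest-scoring fraction of conv channels."""
    scores = sorted(l1_channel_scores(model), key=lambda t: t[2])
    to_prune = {(name, c) for name, c, _ in scores[:int(sparsity * len(scores))]}
    zero_channels(model, to_prune)
    return to_prune
```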

*(Figure: best, pruned, and retrained accuracy vs. sparsity for L1-norm-based channel pruning.)*

### Global Channel Pruning - selecting channels at random

Channels are selected at random and pruned until the desired sparsity is reached. As above, pruning and retraining are performed for all sparsity levels from 0 to 95% in steps of 5%. The experiment is repeated twice to check for variation between runs. A sketch of the random masking step follows, along with the resulting figures.
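
A corresponding sketch of the random variant, under the same assumptions as above and reusing `zero_channels` from the previous sketch.

```python
import random
import torch.nn as nn

def random_channel_prune(model, sparsity, seed=None):
    """Prune a random fraction of conv channels (reuses zero_channels from above)."""
    rng = random.Random(seed)
    channels = [(name, c)
                for name, m in model.named_modules() if isinstance(m, nn.Conv2d)
                for c in range(m.out_channels)]
    to_prune = set(rng.sample(channels, int(sparsity * len(channels))))
    zero_channels(model, to_prune)
    return to_prune
```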

*(Figures: best, pruned, and retrained accuracy vs. sparsity for random channel pruning, two repeated runs.)*

To check how consistently random pruning performs, a total of 20 trials of pruning (at 90% sparsity) and retraining is carried out.
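
For completeness, the trial loop could look like the sketch below, reusing `random_channel_prune` from above together with the hypothetical `retrain`/`evaluate` helpers from the pipeline sketch.

```python
import copy

retrained_accuracies = []
for trial in range(20):
    candidate = copy.deepcopy(model)                      # start from the best model
    random_channel_prune(candidate, sparsity=0.90, seed=trial)
    retrain(candidate, epochs=3)                          # hypothetical fine-tuning helper
    retrained_accuracies.append(evaluate(candidate))      # hypothetical test-set evaluation
```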

*(Figure: results of 20 trials of random channel pruning at 90% sparsity.)*

## Observations

From the graphs, the retrained accuracy either reaches ~1 (the best-model accuracy) or remains stuck at the pruned accuracy. A likely explanation is that the pruning criterion sometimes removes all the channels of one of the layers, which cuts off the forward path and prevents the model from improving any further.

To confirm this reasoning, the above experiments are repeated while keeping track of the layer-wise sparsities at every step/trial. A sketch of the tracking helper is given below, followed by the figures.
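
A small helper (an assumption, not necessarily how the notebook tracks it) that reports the fraction of fully zeroed output channels per conv layer, which is enough to flag a completely pruned layer.

```python
import torch.nn as nn

def layerwise_channel_sparsity(model):
    """Fraction of all-zero output channels per conv layer."""
    stats = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            zeroed = (m.weight.detach().abs().sum(dim=(1, 2, 3)) == 0)
            stats[name] = zeroed.float().mean().item()
    return stats

# e.g. a value of 1.0 for some layer would flag it as fully pruned
```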

Global Channel Pruning - selecting channels at random.

*(Figures: layer-wise sparsities across sparsity levels for random channel pruning.)*

Global Channel Pruning - selecting channels at random, 20 trials of pruning (with 90% sparsity).

*(Figures: layer-wise sparsities over the 20 trials of random channel pruning at 90% sparsity.)*

A clear correlation can be seen in the above plots: whenever one of the layers is completely pruned, the model cannot improve any further. To further strengthen this reasoning, a modified version of random pruning is implemented that preserves at least one channel in each layer (a sketch of this variant follows the figure below). The result of 20 trials (with 90% sparsity) is presented below, indicating that as long as every layer retains at least one channel, retraining recovers near-best performance.

*(Figures: results of 20 trials of random channel pruning at 90% sparsity with at least one channel preserved per layer.)*
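
A sketch of this constrained variant, again under the assumption that pruning means zeroing channels, and reusing `zero_channels` from the earlier sketch.

```python
import random
import torch.nn as nn

def random_channel_prune_keep_one(model, sparsity, seed=None):
    """Random channel pruning that always preserves one channel per conv layer."""
    rng = random.Random(seed)
    convs = [(name, m) for name, m in model.named_modules()
             if isinstance(m, nn.Conv2d)]
    # Reserve one randomly chosen survivor per layer so no layer is emptied.
    survivors = {(name, rng.randrange(m.out_channels)) for name, m in convs}
    candidates = [(name, c) for name, m in convs
                  for c in range(m.out_channels) if (name, c) not in survivors]
    total = len(candidates) + len(survivors)
    n_prune = min(int(sparsity * total), len(candidates))
    to_prune = set(rng.sample(candidates, n_prune))
    zero_channels(model, to_prune)   # same masking step as the earlier sketches
    return to_prune
```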

From the above experiments, it is clear that any randomly pruned mask of LeNet5 that retains at least one channel in each layer can recover near-best performance on MNIST. In this setting, essentially every such mask is equally good, so it does not make much sense to use these masks as training data for a GAN.

## TO-DO

  1. Perform similar experiments for ResNet on the ImageNet dataset.

## Part 2 - Implement a GAN that takes a given network as a parameter and generates a probabilistic mapping of channels for pruning
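
As a starting point for Part 2, the generator could be something as simple as the sketch below: a small MLP that maps a noise vector to a keep-probability for every conv channel of the target network. Everything here (the `MaskGenerator` name, the architecture, the choice of a sigmoid output) is an assumption about one possible design, not an implemented component.

```python
import torch
import torch.nn as nn

class MaskGenerator(nn.Module):
    def __init__(self, target_model, noise_dim=64, hidden=128):
        super().__init__()
        # One output unit per prunable conv channel in the target network.
        self.n_channels = sum(m.out_channels for m in target_model.modules()
                              if isinstance(m, nn.Conv2d))
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, self.n_channels), nn.Sigmoid(),  # keep-probabilities
        )

    def forward(self, z):
        return self.net(z)  # shape [batch, n_channels], values in (0, 1)

# e.g. gen = MaskGenerator(LeNet5())   # LeNet5 from the training sketch above
#      probs = gen(torch.randn(1, 64))
```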