forked from microsoft/nni
Merge pull request #216 from microsoft/master
merge master
Showing 24 changed files with 802 additions and 86 deletions.
@@ -0,0 +1,23 @@
Lottery Ticket Hypothesis on NNI
===

## Introduction

The paper [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) is mainly a measurement and analysis paper that delivers very interesting insights. To support it on NNI, we mainly implement the training approach for finding *winning tickets*.

In the paper, the authors use the following process, called *iterative pruning*, to prune a model:

>1. Randomly initialize a neural network f(x; theta_0) (where theta_0 is drawn from D_theta).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x; m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and pruning runs for n iterations, each iteration prunes 1-(1-P)^(1/n) of the weights that survived the previous round. For example, with P = 0.8 and n = 10, each round prunes about 14.9% of the remaining weights, as the sketch below verifies.
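
As a quick sanity check of this formula, here is a minimal sketch in plain Python (no NNI dependency; the variable names are ours) confirming that pruning at this per-round rate compounds to the configured final sparsity:

```python
# Verify: pruning 1-(1-P)^(1/n) of the surviving weights each round
# compounds to a final sparsity of P after n rounds.
P, n = 0.8, 10
rate = 1 - (1 - P) ** (1 / n)      # fraction pruned per round
remaining = 1.0                    # fraction of weights still unpruned
for _ in range(n):
    remaining *= 1 - rate          # each round keeps (1 - rate) of survivors
print('per-round rate: %.4f' % rate)             # ~0.1487
print('final sparsity: %.4f' % (1 - remaining))  # ~0.8000
```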

## Reproduce Results

We try to reproduce the experiment result of the fully connected network on MNIST, using the same configuration as in the paper. The code is available [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/lottery_torch_mnist_fc.py). In this experiment, we prune the model 10 times; after each pruning, we train the pruned model for 50 epochs.

![](../../img/lottery_ticket_mnist_fc.png)

The above figure shows the result of the fully connected network. `round0-sparsity-0.0` is the performance without pruning. Consistent with the paper, pruning around 80% of the weights yields performance similar to the unpruned model and converges a little faster. Pruning too much, e.g., more than 94%, lowers accuracy and slows convergence slightly. One difference from the paper is that the trends in the paper's data are somewhat clearer.
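
The figure can be regenerated by collecting the per-epoch test accuracy printed by the script below for each pruning round and plotting the curves together. A minimal sketch, assuming matplotlib and a hypothetical `history` dict mapping a round label (e.g., `round0-sparsity-0.0`) to its list of per-epoch accuracies:

```python
# Hypothetical plotting helper; `history` is assumed to be collected from
# the training loop in the script below, one accuracy list per prune round.
import matplotlib.pyplot as plt

def plot_rounds(history):
    for label, accuracies in history.items():
        plt.plot(range(len(accuracies)), accuracies, label=label)
    plt.xlabel('epoch')
    plt.ylabel('test accuracy (%)')
    plt.legend()
    plt.savefig('lottery_ticket_mnist_fc.png')
```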
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import LotteryTicketPruner


class fc1(nn.Module):
    """Fully connected network from the paper: 784-300-100-10."""

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def train(model, train_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()  # loss of the last batch


def test(model, test_loader, criterion):
    # criterion is unused here; only accuracy is returned
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


if __name__ == '__main__':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=True)

    model = fc1().to("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()

    # Prune 10 times to a final sparsity of 96%; 'default' covers the weights
    # of the supported layer types (here, the Linear layers).
    configure_list = [{
        'prune_iterations': 10,
        'sparsity': 0.96,
        'op_types': ['default']
    }]
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()

    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()  # begin the next pruning round
        loss = 0
        accuracy = 0
        for epoch in range(50):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
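
`export_model` writes the pruned weights and the binary masks to separate files. As a hedged sketch of how they could be consumed later (this assumes the `apply_compression_results` helper from `nni.compression.torch` in NNI v1.x, which reapplies a saved mask file to a model; verify against your NNI version):

```python
# Sketch: reload the exported checkpoint and reapply the saved masks.
import torch
from nni.compression.torch import apply_compression_results  # assumed NNI v1.x helper

model = fc1()
model.load_state_dict(torch.load('model.pth'))
apply_compression_results(model, 'mask.pth')  # zero out pruned weights
model.eval()
```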