-
Notifications
You must be signed in to change notification settings - Fork 46
/
test_convnextlarge.py
44 lines (36 loc) · 2.02 KB
/
test_convnextlarge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from only_train_once import OTO
from backends import convnext_large
import unittest
import os
# Directory where OTO writes its visualization and the constructed sub-network
# checkpoints (see oto.visualize / oto.construct_subnet calls below).
OUT_DIR = './cache'
class TestConvNextLarge(unittest.TestCase):
    """Sanity test for OTO structured pruning on ConvNeXt-Large.

    Builds the full model, randomly zeroes parameter groups, constructs the
    compressed sub-network, then compares outputs, on-disk sizes, FLOPs and
    parameter counts between the full and compressed models.
    """

    def test_sanity(self, dummy_input=None):
        """Run the end-to-end prune-and-compare pipeline.

        Args:
            dummy_input: optional input tensor; defaults to a fresh random
                (1, 3, 224, 224) image batch. A ``None`` sentinel is used
                instead of ``torch.rand(...)`` in the signature, because a
                default argument is evaluated once at import time — that
                would freeze a single random tensor for every run and do
                tensor work on module import.
        """
        if dummy_input is None:
            dummy_input = torch.rand(1, 3, 224, 224)
        # layer_scale_init_value will disable a nn.Parameters gamma.
        # The singleton parameter is not supported in OTOv3.
        # Pretrained convnext has such parameters, thereby need to skip a few node groups.
        model = convnext_large(layer_scale_init_value=-1)
        oto = OTO(model, dummy_input)
        oto.visualize(view=False, out_dir=OUT_DIR)
        # Compute FLOP and param for the full model (before pruning), so the
        # reduction ratios at the end are relative to the dense baseline.
        full_flops = oto.compute_flops(in_million=True)['total']
        full_num_params = oto.compute_num_params(in_million=True)
        oto.random_set_zero_groups()
        oto.construct_subnet(out_dir=OUT_DIR)
        full_model = torch.load(oto.full_group_sparse_model_path)
        compressed_model = torch.load(oto.compressed_model_path)
        full_output = full_model(dummy_input)
        compressed_output = compressed_model(dummy_input)
        max_output_diff = torch.max(torch.abs(full_output - compressed_output))
        print("Maximum output difference : ", max_output_diff.item())
        # NOTE(review): the strict numerical check is intentionally disabled;
        # re-enable once the expected tolerance for this model is confirmed.
        # self.assertLessEqual(max_output_diff, 1e-4)
        full_model_stat = os.stat(oto.full_group_sparse_model_path)
        compressed_model_stat = os.stat(oto.compressed_model_path)
        print("Size of full model : ", full_model_stat.st_size / (1024 ** 3), "GBs")
        print("Size of compress model : ", compressed_model_stat.st_size / (1024 ** 3), "GBs")
        # Compute FLOP and param for the pruned model after oto.construct_subnet().
        pruned_flops = oto.compute_flops(in_million=True)['total']
        pruned_num_params = oto.compute_num_params(in_million=True)
        print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops)
        print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params)