-
Notifications
You must be signed in to change notification settings - Fork 1
/
mobilenet_v3_micro.py
90 lines (77 loc) · 3.01 KB
/
mobilenet_v3_micro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import re
from functools import partial
from torch.hub import load_state_dict_from_url
from torchvision.models.mobilenetv3 import InvertedResidualConfig, MobileNetV3, model_urls
from src.models.mobilenet_v3 import MobileNetSegmentV3
# Adapted from torchvision's _mobilenet_v3_conf:
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py#L239
def _mobilenet_v3_micro_conf(width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False):
    """Build the block configuration for a "micro" MobileNetV3 variant.

    Args:
        width_mult: Channel multiplier forwarded to every InvertedResidualConfig
            and to ``InvertedResidualConfig.adjust_channels`` for the last channel.
        reduced_tail: If True, integer-halve the channel counts used by the last
            three blocks (their expanded/output widths) and by ``last_channel``.
        dilated: If True, the last three blocks use dilation 2 instead of 1.

    Returns:
        ``(inverted_residual_setting, last_channel)``: a list of eleven
        ``InvertedResidualConfig`` objects and the width-adjusted last-channel count.
    """
    divider = 2 if reduced_tail else 1
    tail_dilation = 2 if dilated else 1
    make_block = partial(InvertedResidualConfig, width_mult=width_mult)

    # Channel counts of the tail blocks, shrunk when reduced_tail is requested.
    tail_out = 32 // divider
    tail_expanded = 160 // divider

    # Positional arguments presumably follow torchvision's InvertedResidualConfig:
    # (input_channels, kernel, expanded_channels, out_channels, use_se,
    #  activation, stride, dilation) -- verify against the installed version.
    blocks = [
        make_block(4, 3, 8, 8, True, "RE", 2, 1),  # C1
        make_block(8, 3, 16, 8, False, "RE", 1, 1),
        make_block(8, 3, 24, 16, False, "RE", 2, 1),  # C2
        make_block(16, 3, 24, 16, True, "RE", 1, 1),  # C3
        make_block(16, 3, 64, 16, True, "RE", 1, 1),
        make_block(16, 3, 64, 16, True, "RE", 1, 1),
        make_block(16, 3, 48, 24, True, "RE", 1, 1),
        make_block(24, 3, 48, 24, True, "RE", 1, 1),
        make_block(24, 3, 96, tail_out, True, "RE", 2, tail_dilation),  # C4
        make_block(tail_out, 3, tail_expanded, tail_out, True, "RE", 1, tail_dilation),
        make_block(tail_out, 3, tail_expanded, tail_out, True, "RE", 1, tail_dilation),
    ]
    # C5
    last_channel = InvertedResidualConfig.adjust_channels(tail_out, width_mult=width_mult)
    return blocks, last_channel
def _mobilenet_v3_micro(
        arch,
        inverted_residual_setting,
        last_channel,
        pretrained=False,
        progress=True):
    """Instantiate a torchvision MobileNetV3 from a custom block configuration.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint to download when
            ``pretrained`` is True (e.g. ``"mobilenet_v3_small"``). Unused otherwise.
        inverted_residual_setting: List of ``InvertedResidualConfig`` blocks.
        last_channel: Passed to ``MobileNetV3`` as its ``last_channel`` argument.
        pretrained: If True, download the ``arch`` checkpoint and load every
            tensor whose name and shape fit this model; all others are skipped.
        progress: Forwarded to ``load_state_dict_from_url`` (download progress bar).

    Returns:
        The constructed ``MobileNetV3`` model.

    Raises:
        ValueError: If ``pretrained`` is True and ``model_urls`` has no entry for ``arch``.
    """
    model = MobileNetV3(inverted_residual_setting, last_channel)
    if pretrained:
        # NOTE(review): newer torchvision releases dropped `model_urls` from
        # torchvision.models.mobilenetv3; confirm the pinned torchvision version.
        url = model_urls.get(arch)
        if url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(url, progress=progress)

        # Always drop the tail blocks and the classifier: their configuration
        # differs from the reference architecture.
        skip = re.compile(r'features\.([2-9]|10|11|12)|classifier\.')
        # load_state_dict raises on size mismatches even with strict=False, so
        # also drop any remaining tensor (e.g. features.0 / features.1 when
        # their widths differ from the checkpoint's) that does not fit this model.
        model_state = model.state_dict()
        compatible = {
            key: tensor
            for key, tensor in state_dict.items()
            if not skip.search(key)
            and key in model_state
            and tensor.shape == model_state[key].shape
        }
        model.load_state_dict(compatible, strict=False)
    return model
# Default micro backbone: full width, full tail, no dilation, random init.
# The arch name is only consulted when pretrained weights are requested.
inverted_residual_setting, last_channel = _mobilenet_v3_micro_conf(
    width_mult=1.0,
    reduced_tail=False,
    dilated=False,
)
backbone = _mobilenet_v3_micro(
    arch="mobilenet_v3_small",
    inverted_residual_setting=inverted_residual_setting,
    last_channel=last_channel,
)
# Segmentation model config: the micro backbone, its output channel count, and
# the input/output spatial sizes (both given as width x height).
model = {
    "type": MobileNetSegmentV3,
    "backbone": backbone,
    # NOTE(review): torchvision's MobileNetV3 ends its feature stack with an
    # extra expansion conv, so the backbone may emit more than `last_channel`
    # channels -- confirm against how MobileNetSegmentV3 consumes the backbone.
    "backbone_out_ch": last_channel,
    "input_size": (80, 48),  # WxH
    "output_size": (20, 12),  # WxH
}
# Optimizer and learning-rate schedule settings.
# NOTE(review): key names (lr0, lrf, warmup_bias_lr, ...) resemble YOLOv5-style
# hyperparameters; their exact interpretation is up to the trainer -- confirm.
solver = {
    "optim": "SGD",
    "lr_scheduler": "Cosine",
    "lr0": 0.01,
    "lrf": 0.01,
    "momentum": 0.937,
    "weight_decay": 0.0005,
    "warmup_epochs": 3.0,
    "warmup_momentum": 0.8,
    "warmup_bias_lr": 0.1,
}
# Data-augmentation settings (HSV jitter, geometric transforms, flips, mosaic/mixup).
# NOTE(review): names and ranges resemble YOLOv5-style augmentation
# hyperparameters; the exact semantics depend on the data pipeline -- confirm.
data_aug = {
    "hsv_h": 0.015,
    "hsv_s": 0.7,
    "hsv_v": 0.4,
    "degrees": 0.0,
    "translate": 0.1,
    "scale": 0.5,
    "shear": 0.0,
    "flipud": 0.0,
    "fliplr": 0.5,
    "mosaic": 1.0,
    "mixup": 0.0,
}