- July 28, 2021. MGD in Unsupervised Learning.
This implementation is based on the official PyTorch ImageNet training code, which supports two training modes: DataParallel (DP) and DistributedDataParallel (DDP). MGD for object detection is also re-implemented in Detectron2 as an external project.
Note: T : teacher feature tensors. S : student feature tensors. dp : distance function for distillation. Ci: i-th channel.
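To make the notation concrete, here is a toy sketch in plain Python rather than the repository's implementation. It assumes that each student channel is matched to an equal number of teacher channels, that the matched teacher channels are reduced by absolute max pooling, and that the distance function is a squared Euclidean distance. The brute-force search stands in for a proper assignment solver, and all names (`match_channels`, `abs_max_pool`, `sq_dist`) are hypothetical.

```python
from itertools import permutations


def sq_dist(a, b):
    """Squared Euclidean distance between two flattened channels."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def match_channels(teacher, student):
    """Brute-force balanced assignment of teacher channels to student channels.

    Every student channel receives len(teacher) // len(student) teacher
    channels; the grouping with the smallest total distance wins.
    Only feasible for tiny inputs (illustration only).
    """
    group = len(teacher) // len(student)
    best_cost, best_groups = float("inf"), None
    for perm in permutations(range(len(teacher))):
        groups = [perm[i * group:(i + 1) * group] for i in range(len(student))]
        cost = sum(
            sq_dist(teacher[t], student[s])
            for s, g in enumerate(groups)
            for t in g
        )
        if cost < best_cost:
            best_cost, best_groups = cost, [sorted(g) for g in groups]
    return best_groups


def abs_max_pool(channels):
    """At each position, keep the value with the largest magnitude."""
    return [max(vals, key=abs) for vals in zip(*channels)]


# T: four teacher channels, S: two student channels (two positions each).
teacher = [[1.0, 0.0], [0.9, 0.2], [-1.0, 0.5], [-0.8, 0.4]]
student = [[-0.9, 0.5], [1.0, 0.1]]

groups = match_channels(teacher, student)
reduced = [abs_max_pool([teacher[t] for t in g]) for g in groups]
loss = sum(sq_dist(r, s) for r, s in zip(reduced, student))
print(groups, reduced, loss)
```

In this toy case, teacher channels 2 and 3 are grouped with the first student channel and channels 0 and 1 with the second, so the reduced teacher tensor has the same channel count as the student.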
@inproceedings{eccv20mgd,
title = {Matching Guided Distillation},
author = {Yue, Kaiyu and Deng, Jiangfan and Zhou, Feng},
booktitle = {European Conference on Computer Vision (ECCV)},
year = {2020}
}
- Python - 3.7
- PyTorch - 1.5.0 with torchvision - 0.6.0
- Detectron2 Tree - 369a57d333
As an example, we use ResNet-50 as the teacher to distill ResNet-18, as shown in the figure below.
Note: models are from torchvision.
Install OR-Tools by running pip install ortools.
This function exposes the intermediate features and the final output logits. All you need to do is copy the body of the original forward method and return whichever tensors you want to use for distillation. Reference.
def extract_feature(self, x, preReLU=False):
    ...
    feat3 = self.layer3(x)  # we expose layer3 output
    x = self.layer4(feat3)
    ...
    if not preReLU:
        feat3 = F.relu(feat3)
    return [feat3], x
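The same pattern can be tried on a tiny stand-in network. The sketch below is hypothetical (`TinyNet` is not part of this repository or torchvision) and only illustrates copying the forward pass and returning an intermediate tensor together with the logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical two-stage network used only to illustrate the pattern."""

    def __init__(self):
        super().__init__()
        self.layer3 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.layer4 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        return self.fc(x.mean(dim=(2, 3)))  # global average pool, then classify

    def extract_feature(self, x, preReLU=False):
        # Same computation as forward(), but the layer3 output is kept.
        feat3 = self.layer3(x)
        x = F.relu(self.layer4(F.relu(feat3)))
        logits = self.fc(x.mean(dim=(2, 3)))
        if not preReLU:
            feat3 = F.relu(feat3)
        return [feat3], logits


torch.manual_seed(0)
net = TinyNet().eval()
inputs = torch.randn(2, 3, 4, 4)
with torch.no_grad():
    feats, logits = net.extract_feature(inputs, preReLU=False)
    same_logits = torch.allclose(logits, net(inputs))
print(feats[0].shape, logits.shape, same_logits)
```

Returning the features as a list keeps the interface unchanged if more distillation positions are exposed later.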
The function exposes BN layers before the distillation position. Reference.
def get_bn_before_relu(self):
    if isinstance(self.layer1[0], Bottleneck):
        bn3 = self.layer3[-1].bn3
    elif isinstance(self.layer1[0], BasicBlock):
        bn3 = self.layer3[-1].bn2
    else:
        # a bare `raise` has no active exception to re-raise, so raise explicitly
        raise TypeError('ResNet unknown block error !!!')
    return [bn3]
The function tells MGD the channel number of the intermediate feature maps. Reference.
def get_channel_num(self):
    return [1024]
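The value 1024 corresponds to the layer3 output of a Bottleneck-based ResNet such as the ResNet-50 teacher. As a minimal sketch, assuming torchvision's ResNet layout (256 base planes in layer3, block expansion of 1 for BasicBlock and 4 for Bottleneck), the channel count for either network can be derived as follows. The helper `layer3_channels` is hypothetical and not part of this repository.

```python
# Block expansion factors as defined in torchvision's ResNet blocks.
EXPANSION = {"BasicBlock": 1, "Bottleneck": 4}
LAYER3_PLANES = 256  # base width of layer3 in torchvision's ResNet


def layer3_channels(block_name):
    """Number of output channels of layer3 for the given block type."""
    return LAYER3_PLANES * EXPANSION[block_name]


print(layer3_channels("Bottleneck"))  # ResNet-50 teacher: 1024
print(layer3_channels("BasicBlock"))  # ResNet-18 student: 256
```

Under these assumptions, the student's version of the method would report 256 channels where the teacher reports 1024.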
t_net = resnet50()  # teacher model
s_net = resnet18()  # student model
import mgd.builder
d_net = mgd.builder.MGDistiller(
    t_net,
    s_net,
    ignore_inds=[],
    reducer='amp',
    sync_bn=False,
    with_kd=True,
    preReLU=True,
    distributed=False,  # DP mode: False | DDP mode: True
    det=False  # work within Detectron2
)
# initialize the MGD parameters once before training starts
mgd_update(train_loader, d_net)
# training loop
for epoch in range(total_epochs):
    # UPDATE_FREQ is user-defined
    if (epoch + 1) % UPDATE_FREQ == 0:
        mgd_update(train_loader, d_net)
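To see which epochs trigger an update under the condition above, the schedule can be listed with a small stand-alone helper. `update_epochs` is a hypothetical name used only for this illustration, and the values 10 and 3 are arbitrary.

```python
def update_epochs(total_epochs, update_freq):
    """Zero-based epochs at which (epoch + 1) % update_freq == 0 holds."""
    return [
        epoch
        for epoch in range(total_epochs)
        if (epoch + 1) % update_freq == 0
    ]


print(update_epochs(10, 3))  # [2, 5, 8]
```

With an update frequency of 3, the matching is refreshed at the end of every third epoch, in addition to the initial call before the loop.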
Classification | Object Detection | Unsupervised Learning.
We learned from and reused parts of the code from the following projects, and we thank their authors for these excellent works:
- A Comprehensive Overhaul of Feature Distillation, ICCV'19.
- Detectron2. FAIR's next-generation platform for object detection and segmentation.
MIT. See LICENSE for details.