From 5f5e8e83c24661f28bf43719dd67fa8a240d7408 Mon Sep 17 00:00:00 2001 From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Date: Wed, 24 Mar 2021 13:25:36 +0800 Subject: [PATCH] [Refactoring] Add Caffe2Xavier Initializer (#902) * [Refactoring] Add Caffe2Xavier Initializer * fix lint --- mmcv/cnn/__init__.py | 7 ++++--- mmcv/cnn/utils/__init__.py | 8 ++++---- mmcv/cnn/utils/weight_init.py | 16 ++++++++++++++++ tests/test_cnn/test_weight_init.py | 18 ++++++++++++++---- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py index 06f2980219..41cf85d4ca 100644 --- a/mmcv/cnn/__init__.py +++ b/mmcv/cnn/__init__.py @@ -13,8 +13,8 @@ build_upsample_layer, conv_ws_2d, is_norm) # yapf: enable from .resnet import ResNet, make_res_layer -from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, - PretrainedInit, UniformInit, XavierInit, +from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, + NormalInit, PretrainedInit, UniformInit, XavierInit, bias_init_with_prob, caffe2_xavier_init, constant_init, fuse_conv_bn, get_model_complexity_info, initialize, kaiming_init, normal_init, uniform_init, xavier_init) @@ -33,5 +33,6 @@ 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit', - 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit' + 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit' ] diff --git a/mmcv/cnn/utils/__init__.py b/mmcv/cnn/utils/__init__.py index 99ec08a786..18efa4135f 100644 --- a/mmcv/cnn/utils/__init__.py +++ b/mmcv/cnn/utils/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) Open-MMLab. All rights reserved. from .flops_counter import get_model_complexity_info from .fuse_conv_bn import fuse_conv_bn -from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, - PretrainedInit, UniformInit, XavierInit, - bias_init_with_prob, caffe2_xavier_init, +from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit, + KaimingInit, NormalInit, PretrainedInit, UniformInit, + XavierInit, bias_init_with_prob, caffe2_xavier_init, constant_init, initialize, kaiming_init, normal_init, uniform_init, xavier_init) @@ -12,5 +12,5 @@ 'constant_init', 'kaiming_init', 'normal_init', 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', - 'PretrainedInit' + 'PretrainedInit', 'Caffe2XavierInit' ] diff --git a/mmcv/cnn/utils/weight_init.py b/mmcv/cnn/utils/weight_init.py index f9880a1906..9125f6f549 100644 --- a/mmcv/cnn/utils/weight_init.py +++ b/mmcv/cnn/utils/weight_init.py @@ -298,6 +298,22 @@ def init(m): module.apply(init) +@INITIALIZERS.register_module(name='Caffe2Xavier') +class Caffe2XavierInit(KaimingInit): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + def __init__(self, **kwargs): + super().__init__( + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) + + def __call__(self, module): + super().__call__(module) + + @INITIALIZERS.register_module(name='Pretrained') class PretrainedInit(object): """Initialize module by loading a pretrained model. diff --git a/tests/test_cnn/test_weight_init.py b/tests/test_cnn/test_weight_init.py index c0113b9545..2d5df451bc 100644 --- a/tests/test_cnn/test_weight_init.py +++ b/tests/test_cnn/test_weight_init.py @@ -6,10 +6,11 @@ import torch from torch import nn -from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit, - UniformInit, XavierInit, bias_init_with_prob, - caffe2_xavier_init, constant_init, initialize, - kaiming_init, normal_init, uniform_init, xavier_init) +from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit, + PretrainedInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, constant_init, + initialize, kaiming_init, normal_init, uniform_init, + xavier_init) def test_constant_init(): @@ -219,6 +220,15 @@ def test_kaiminginit(): assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.)) +def test_caffe2xavierinit(): + """test Caffe2XavierInit.""" + model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2)) + func = Caffe2XavierInit(bias=0.1, layer='Conv2d') + func(model) + assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1)) + assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1)) + + class FooModule(nn.Module): def __init__(self):