
[Refactoring] Add Caffe2Xavier Initializer (#902)
* [Refactoring] Add Caffe2Xavier Initializer

* fix lint
MeowZheng authored Mar 24, 2021
1 parent 933b052 commit 5f5e8e8
Showing 4 changed files with 38 additions and 11 deletions.
7 changes: 4 additions & 3 deletions mmcv/cnn/__init__.py
@@ -13,8 +13,8 @@
                      build_upsample_layer, conv_ws_2d, is_norm)
 # yapf: enable
 from .resnet import ResNet, make_res_layer
-from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
-                    PretrainedInit, UniformInit, XavierInit,
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+                    NormalInit, PretrainedInit, UniformInit, XavierInit,
                     bias_init_with_prob, caffe2_xavier_init, constant_init,
                     fuse_conv_bn, get_model_complexity_info, initialize,
                     kaiming_init, normal_init, uniform_init, xavier_init)
@@ -33,5 +33,6 @@
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
     'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
     'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
-    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
+    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit'
 ]
8 changes: 4 additions & 4 deletions mmcv/cnn/utils/__init__.py
@@ -1,9 +1,9 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
-from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
-                          PretrainedInit, UniformInit, XavierInit,
-                          bias_init_with_prob, caffe2_xavier_init,
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+                          KaimingInit, NormalInit, PretrainedInit, UniformInit,
+                          XavierInit, bias_init_with_prob, caffe2_xavier_init,
                           constant_init, initialize, kaiming_init, normal_init,
                           uniform_init, xavier_init)

@@ -12,5 +12,5 @@
     'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
     'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
     'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
-    'PretrainedInit'
+    'PretrainedInit', 'Caffe2XavierInit'
 ]
16 changes: 16 additions & 0 deletions mmcv/cnn/utils/weight_init.py
@@ -298,6 +298,22 @@ def init(m):
         module.apply(init)


+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    def __init__(self, **kwargs):
+        super().__init__(
+            a=1,
+            mode='fan_in',
+            nonlinearity='leaky_relu',
+            distribution='uniform',
+            **kwargs)
+
+    def __call__(self, module):
+        super().__call__(module)
+
+
 @INITIALIZERS.register_module(name='Pretrained')
 class PretrainedInit(object):
     """Initialize module by loading a pretrained model.
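For context, the new class only pins KaimingInit's parameters to the values implied by Caffe2's XavierFill (a=1, fan_in, leaky_relu, uniform distribution), which is also what the existing caffe2_xavier_init helper uses. The sketch below is illustrative only: the toy model and config are not part of this commit, and the registry route via initialize is assumed to work the same way as for the other registered initializers.

    import torch.nn as nn

    from mmcv.cnn import Caffe2XavierInit, initialize

    # Toy model; only the Conv2d layers are matched by layer='Conv2d'.
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 2, 1))

    # Direct use of the initializer object, as in the new unit test below.
    init_fn = Caffe2XavierInit(bias=0.1, layer='Conv2d')
    init_fn(model)

    # Assumed equivalent: go through the INITIALIZERS registry via `initialize`,
    # using the name registered by this commit.
    initialize(model, dict(type='Caffe2Xavier', layer='Conv2d', bias=0.1))

    # Either way, each matched weight is expected to end up as
    # nn.init.kaiming_uniform_(w, a=1, mode='fan_in', nonlinearity='leaky_relu')
    # would leave it, with the bias filled with 0.1.

Subclassing KaimingInit keeps the registry entry to a thin wrapper instead of duplicating the Kaiming logic.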
18 changes: 14 additions & 4 deletions tests/test_cnn/test_weight_init.py
@@ -6,10 +6,11 @@
 import torch
 from torch import nn

-from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
-                      UniformInit, XavierInit, bias_init_with_prob,
-                      caffe2_xavier_init, constant_init, initialize,
-                      kaiming_init, normal_init, uniform_init, xavier_init)
+from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
+                      PretrainedInit, UniformInit, XavierInit,
+                      bias_init_with_prob, caffe2_xavier_init, constant_init,
+                      initialize, kaiming_init, normal_init, uniform_init,
+                      xavier_init)


 def test_constant_init():
@@ -219,6 +220,15 @@ def test_kaiminginit():
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))


+def test_caffe2xavierinit():
+    """test Caffe2XavierInit."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+    func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
+    func(model)
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
+    assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
+
+
 class FooModule(nn.Module):

     def __init__(self):