ENAS and DARTS search space zoo #2589
@@ -0,0 +1,2 @@
from .darts_cell import DartsCell
from .darts_search_space import DartsSearchSpace
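The package's two exports are meant to be driven by an NNI NAS trainer rather than called directly, since the LayerChoice/InputChoice modules inside are resolved by a mutator. Below is a minimal sketch of wiring the search space into NNI's DARTS trainer; the import path of this package, the dataset setup, and the exact DartsTrainer argument list follow the NNI 1.x API as I understand it and are illustrative assumptions, not part of this diff.

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from nni.nas.pytorch.darts import DartsTrainer

from darts_search_space import DartsSearchSpace  # hypothetical import path for this package

def accuracy(output, target):
    # top-1 accuracy, returned as the metrics dict NNI trainers expect
    return {"acc1": (output.argmax(dim=1) == target).float().mean().item()}

transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
valid_set = datasets.CIFAR10("./data", train=False, download=True, transform=transform)

model = DartsSearchSpace(in_channels=3, channels=16, n_classes=10, n_layers=8)
trainer = DartsTrainer(model,
                       loss=nn.CrossEntropyLoss(),
                       metrics=accuracy,
                       optimizer=torch.optim.SGD(model.parameters(), 0.025, momentum=0.9),
                       num_epochs=50,
                       dataset_train=train_set,
                       dataset_valid=valid_set)
trainer.train()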
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables


class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        """
        Built-in DARTS node structure.

        Parameters
        ----------
        node_id : str
            the unique key of this node, used to name its mutables
        num_prev_nodes : int
            the number of previous nodes in this cell
        channels : int
            the number of output channels
        num_downsample_connect : int
            the number of leading predecessor connections that downsample
            their input (i.e. use stride 2)
        """
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
        out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)


class DartsCell(nn.Module):
    """
    Built-in DARTS cell structure.

    Parameters
    ----------
    n_nodes : int
        the number of nodes contained in this cell
    channels_pp : int
        the number of output channels of the cell before the previous cell (cell[k-2])
    channels_p : int
        the number of output channels of the previous cell (cell[k-1])
    channels : int
        the number of output channels of each node
    reduction_p : bool
        whether the previous cell is a reduction cell
    reduction : bool
        whether the current cell is a reduction cell
    """
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If the previous cell is a reduction cell, the current input size does not
        # match the output size of cell[k-2], so the output of cell[k-2] has to be
        # downsampled during preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate DAG
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0 and s1 are the outputs of cell[k-2] and cell[k-1], respectively
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        # concatenate the outputs of all intermediate nodes along the channel dimension
        output = torch.cat(tensors[2:], dim=1)
        return output
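To make the Node semantics concrete: during search every predecessor goes through its own LayerChoice, and the InputChoice (n_chosen=2, default reduction "sum") keeps two of the candidate outputs and sums them. Below is a standalone sketch of the fixed behavior a searched node collapses to, in plain PyTorch with hard-coded choices; the specific ops and predecessor indices are illustrative, not taken from this PR.

import torch
import torch.nn as nn

class FixedNode(nn.Module):
    # Illustrative fixed counterpart of Node: two chosen predecessors,
    # one chosen op each, outputs summed (InputChoice with n_chosen=2
    # reduces its selected inputs by summation).
    def __init__(self, channels):
        super().__init__()
        self.op0 = nn.MaxPool2d(3, stride=1, padding=1)          # stands in for "maxpool"
        self.op1 = nn.Conv2d(channels, channels, 3, padding=1)   # stands in for a conv choice
        self.chosen = (0, 1)  # indices of the two selected predecessors

    def forward(self, prev_nodes):
        i, j = self.chosen
        return self.op0(prev_nodes[i]) + self.op1(prev_nodes[j])

x0 = torch.randn(1, 16, 32, 32)
x1 = torch.randn(1, 16, 32, 32)
node = FixedNode(16)
print(node([x0, x1]).shape)  # torch.Size([1, 16, 32, 32])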
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

import ops
from darts_cell import DartsCell


Review comment: I think DARTS search space is a cell rather than a full model. Suggest changing it into …
Reply: on the way
Reply: please review new commit 99c841b

class DartsSearchSpace(nn.Module):
    """
    Built-in DARTS search space.
    Compared to the DARTS example, DartsSearchSpace removes the auxiliary head,
    which is considered a training trick rather than part of the model.

    Parameters
    ----------
    in_channels : int
        the number of input channels
    channels : int
        the number of initial channels expected
    n_classes : int
        the number of classes for final classification
    n_layers : int
        the number of cells contained in this network
    n_nodes : int
        the number of nodes contained in each cell
    stem_multiplier : int
        the multiplier applied to `channels` to get the stem's output channel count
    """
    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # halve the feature map size and double the channels at 1/3 and 2/3 of the network depth
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = DartsCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        return logits

    def drop_path_prob(self, p):
        # set the drop probability of every DropPath module in the network
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p
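The channel bookkeeping in the constructor is easy to misread: each cell outputs c_cur * n_nodes channels because the node outputs are concatenated, while c_cur itself only doubles at the two reduction points. Here is a small standalone sketch that replays the loop and prints the per-cell plan; the default values mirror channels=16, n_layers=8, n_nodes=4 and are only an example.

def channel_plan(channels=16, n_layers=8, n_nodes=4, stem_multiplier=3):
    c_cur = stem_multiplier * channels
    channels_pp, channels_p, c_cur = c_cur, c_cur, channels
    reduction = False
    for i in range(n_layers):
        reduction = False
        if i in [n_layers // 3, 2 * n_layers // 3]:
            c_cur *= 2
            reduction = True
        # each cell concatenates its n_nodes node outputs -> c_cur * n_nodes channels
        print(f"cell {i}: in=({channels_pp}, {channels_p}) node_channels={c_cur} "
              f"out={c_cur * n_nodes} reduction={reduction}")
        channels_pp, channels_p = channels_p, c_cur * n_nodes

channel_plan()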
@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class DropPath(nn.Module):
    def __init__(self, p=0.):
        """
        Drop path with probability.

        Parameters
        ----------
        p : float
            Probability of a path being zeroed.
        """
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep_prob = 1. - self.p
            # per data point mask; rescale by keep_prob to preserve the expected activation
            mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
            return x / keep_prob * mask

        return x


class PoolBN(nn.Module):
    """
    AvgPool or MaxPool followed by BN. `pool_type` must be `max` or `avg`.
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        super().__init__()
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError("pool_type must be 'max' or 'avg'")

        self.bn = nn.BatchNorm2d(C, affine=affine)

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        return out


class StdConv(nn.Module):
    """
    Standard conv: ReLU - Conv - BN.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)

Review comment: Suggest keeping all convolutions in a separate file.
Reply: Custom ops will be implemented later. I don't think it is appropriate to expose these built-in ops now.

Review comment: You can actually inherit …


class FacConv(nn.Module):
    """
    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN.
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class DilConv(nn.Module):
    """
    (Dilated) depthwise separable conv.
    ReLU - (Dilated) depthwise separable - Pointwise - BN.
    If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
                      bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class SepConv(nn.Module):
    """
    Depthwise separable conv.
    Two DilConv blocks (dilation=1) stacked.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
            DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by a factorized pointwise convolution (stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # the second branch is shifted by one pixel so the two branches sample different positions
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
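Two quick sanity checks for the ops above, assuming ops.py is importable: DropPath rescales surviving paths by 1/keep_prob so the expected activation is unchanged in training mode, and FactorizedReduce halves the spatial size while its two stride-2 branches together produce C_out channels. The shapes and values below are illustrative.

import torch
from ops import DropPath, FactorizedReduce  # assumes ops.py is on the path

drop = DropPath(p=0.5)
drop.train()
x = torch.ones(10000, 1, 1, 1)
# surviving paths are scaled by 1 / keep_prob = 2, so the mean stays close to 1
print(drop(x).mean().item())

reduce = FactorizedReduce(C_in=16, C_out=32)
reduce.eval()
y = reduce(torch.randn(1, 16, 32, 32))
print(y.shape)  # torch.Size([1, 32, 16, 16])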
Review comment: You should state all the details here.