This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

ENAS and DARTS search space zoo #2589

Merged

merged 47 commits on Jul 27, 2020
Changes from 2 commits

Commits (47)
673cf3d
add darts cell and search space
tabVersion Jun 23, 2020
72f9f12
move to search_space_zoo
tabVersion Jun 23, 2020
99c841b
accept a cell to build full model
tabVersion Jun 24, 2020
e0e9e2c
fix compile error
tabVersion Jun 24, 2020
b55c6cd
bug fix
tabVersion Jun 24, 2020
3e162f4
change DartsCell signature
tabVersion Jun 27, 2020
181e4c1
format code
tabVersion Jun 27, 2020
e35ff4b
change signature & inherit Sequential
tabVersion Jun 29, 2020
7896cb4
add search space example
tabVersion Jun 29, 2020
cd4eb1f
structure adjust & comment change
tabVersion Jun 29, 2020
736d196
clarify darts search space doc
tabVersion Jun 29, 2020
8c4f0bc
move dartsStackCells to example
tabVersion Jun 30, 2020
cf720c9
update docs
tabVersion Jul 3, 2020
823b0be
Merge branch 'master' into darts
tabVersion Jul 3, 2020
5d41f19
doc missing fix
tabVersion Jul 3, 2020
510dc38
Merge branch 'darts' of https://github.com/tabVersion/nni into darts
tabVersion Jul 3, 2020
03f7a28
doc fix
tabVersion Jul 6, 2020
40c517a
change code to fix doc
tabVersion Jul 6, 2020
8696f96
enas test
tabVersion Jul 6, 2020
1a46bf0
enas test
tabVersion Jul 6, 2020
40ab64a
enas test
tabVersion Jul 6, 2020
473e247
enas micro
tabVersion Jul 6, 2020
94f3eba
code format & doc fix & add example
tabVersion Jul 6, 2020
9efdb8a
refine doc
tabVersion Jul 7, 2020
5e2ed66
code format
tabVersion Jul 7, 2020
4cfdb10
add enas micro doc
tabVersion Jul 8, 2020
6a7a6ba
fix trailing whitespace
tabVersion Jul 8, 2020
8ddf8f1
add enas macro
tabVersion Jul 9, 2020
c0aecff
format doc
tabVersion Jul 9, 2020
c785024
fix doc
tabVersion Jul 9, 2020
0316d31
fix syntax
tabVersion Jul 9, 2020
f6e9565
fix
tabVersion Jul 9, 2020
1b0c398
refine doc
tabVersion Jul 11, 2020
5b3dc94
refine doc
tabVersion Jul 13, 2020
f12df2c
update
tabVersion Jul 13, 2020
6cdfc5c
refine
tabVersion Jul 14, 2020
ec6ac2b
refine doc
tabVersion Jul 15, 2020
7ef03a6
refine doc
tabVersion Jul 16, 2020
199efb7
doc refine
tabVersion Jul 20, 2020
5b9c3ae
change sketch
tabVersion Jul 22, 2020
d5f63e2
change illustration
tabVersion Jul 24, 2020
05875f0
resolution fix
tabVersion Jul 24, 2020
2a5c434
update doc
tabVersion Jul 24, 2020
1d12bec
update doc
tabVersion Jul 24, 2020
de282c5
update doc
tabVersion Jul 24, 2020
63b20ce
doc
tabVersion Jul 24, 2020
8eb2afa
adjust menu sequence
tabVersion Jul 27, 2020
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
@@ -2,4 +2,4 @@
# Licensed under the MIT license.

from .mutator import DartsMutator
from .trainer import DartsTrainer
from .trainer import DartsTrainer
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/search_space_zoo/__init__.py
@@ -0,0 +1,2 @@
from .darts_cell import DartsCell
from .darts_search_space import DartsSearchSpace
99 changes: 99 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_cell.py
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn

from . import ops
from nni.nas.pytorch import mutables


class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
'''
Built-in DARTS node structure.

Parameters
----------
node_id : str
    string identifier of the node, used to build the keys of its mutables
num_prev_nodes : int
    the number of previous nodes in this cell
channels : int
    the number of output channels
num_downsample_connect : int
    the number of incoming connections applied with stride 2
    (2 for reduction cells, 0 otherwise)
'''
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)


class DartsCell(nn.Module):
'''
Built-in DARTS cell structure.

Parameters
----------
n_nodes : int
    the number of nodes contained in this cell
channels_pp : int
    the number of output channels of the cell before the previous cell
channels_p : int
    the number of output channels of the previous cell
channels : int
    the number of output channels of each node
reduction_p : bool
    whether the previous cell is a reduction cell
reduction : bool
    whether the current cell is a reduction cell
'''

[Contributor] You should state all the details here.
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes

# If the previous cell is a reduction cell, the current input size does not match
# the output size of cell[k-2], so the output of cell[k-2] must be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))

def forward(self, s0, s1):
# s0 and s1 are the outputs of the cell before the previous cell and of the previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for node in self.mutable_ops:
cur_tensor = node(tensors)
tensors.append(cur_tensor)

output = torch.cat(tensors[2:], dim=1)
return output
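
To make the cell's contract concrete, here is a minimal usage sketch (not part of the diff) that drives a DartsCell forward pass on dummy tensors. The channel sizes are illustrative assumptions, and a DartsMutator (exported by nni.nas.pytorch.darts, as in the first diff above) is assumed to bind and sample the LayerChoice/InputChoice mutables before forward is called:

import torch
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.search_space_zoo import DartsCell

# Illustrative sizes: both predecessor cells output 48 channels,
# and every node in this cell outputs 16 channels.
cell = DartsCell(n_nodes=4, channels_pp=48, channels_p=48, channels=16,
                 reduction_p=False, reduction=False)
mutator = DartsMutator(cell)  # binds the mutables inside the cell
mutator.reset()               # samples the decisions used by forward

s0 = torch.randn(2, 48, 32, 32)  # output of cell[k-2]
s1 = torch.randn(2, 48, 32, 32)  # output of cell[k-1]
out = cell(s0, s1)
# The n_nodes = 4 node outputs are concatenated: 4 * 16 = 64 channels.
print(out.shape)  # expected: torch.Size([2, 64, 32, 32])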
84 changes: 84 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_search_space.py
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

from . import ops
from .darts_cell import DartsCell


class DartsSearchSpace(nn.Module):
[Contributor] I think the DARTS search space is a cell rather than a full model. Suggest changing this into DartsStackedCells, which accepts a cell class to build the full model.

[Contributor, author] on the way

[Contributor, author] (replying to the above) please review new commit 99c841b

'''
Built-in DARTS search space.
Compared to the DARTS example, DartsSearchSpace removes the auxiliary head,
which is considered a training trick rather than part of the model.

Parameters
----------
in_channels : int
    the number of input channels
channels : int
    the number of initial channels expected
n_classes : int
    the number of classes for the final classification
n_layers : int
    the number of cells contained in this network
n_nodes : int
    the number of nodes contained in each cell
stem_multiplier : int
    multiplier for the number of channels produced by the stem
'''
def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers

c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)

# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels

self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
# Reduce the feature map size and double the channels at the 1/3 and 2/3 points of the network.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True

cell = DartsCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out

self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)

def forward(self, x):
s0 = s1 = self.stem(x)

for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)

out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)

return logits

def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
module.p = p
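
For an end-to-end picture, a minimal sketch (not part of the diff) that stacks the cells into the full search space for a CIFAR-10-sized input; all hyperparameter values below are illustrative assumptions:

import torch
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.search_space_zoo import DartsSearchSpace

model = DartsSearchSpace(in_channels=3, channels=16, n_classes=10,
                         n_layers=8, n_nodes=4, stem_multiplier=3)
mutator = DartsMutator(model)  # binds all mutables in the stacked cells
mutator.reset()                # samples the decisions used by forward
model.drop_path_prob(0.2)      # sets p on every ops.DropPath module

x = torch.randn(2, 3, 32, 32)  # CIFAR-10-like batch
logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 10])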
136 changes: 136 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/search_space_zoo/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.

Parameters
----------
p : float
Probability of a path being zeroed.
"""
super().__init__()
self.p = p

def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask

return x
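
As a sanity check on the rescaling above: dividing by keep_prob zeroes a path with probability p while leaving the expected activation unchanged, in the style of inverted dropout. A short sketch using the DropPath class just defined (values illustrative):

import torch

drop = DropPath(p=0.4)
drop.train()  # the mask is only applied in training mode
x = torch.ones(10000, 1, 1, 1)
out = drop(x)
# Each sample is either 0 or 1 / keep_prob, so the mean stays near 1.
print(out.mean().item())  # ~1.0 up to sampling noise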


class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError("pool_type must be 'max' or 'avg'")

self.bn = nn.BatchNorm2d(C, affine=affine)

def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out


class StdConv(nn.Module):
[Contributor] Suggest keeping all convolutions in a separate file.

[Contributor, author] Custom ops will be implemented later. I don't think it is appropriate to expose these built-in ops now.

"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
[Contributor] You can actually inherit nn.Sequential here and use self.add_module.

nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)

def forward(self, x):
return self.net(x)


class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)

def forward(self, x):
return self.net(x)


class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)

def forward(self, x):
return self.net(x)
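
The receptive-field note in the docstring follows from the effective kernel size of a dilated convolution, k_eff = k + (k - 1) * (d - 1): with dilation 2, a 3x3 kernel spans 5 positions and a 5x5 kernel spans 9. A short sketch with illustrative sizes, confirming that the paddings used in this file preserve spatial size at stride 1:

import torch

x = torch.randn(1, 16, 32, 32)
# padding=2 for 3x3 and padding=4 for 5x5 equal (k_eff - 1) / 2 with dilation=2.
conv3 = DilConv(16, 16, kernel_size=3, stride=1, padding=2, dilation=2)
conv5 = DilConv(16, 16, kernel_size=5, stride=1, padding=4, dilation=2)
print(conv3(x).shape)  # torch.Size([1, 16, 32, 32])
print(conv5(x).shape)  # torch.Size([1, 16, 32, 32])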


class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)

def forward(self, x):
return self.net(x)


class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise convolutions (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)

def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
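
FactorizedReduce halves the spatial resolution with two stride-2 1x1 convolutions, the second applied to the input shifted by one pixel, and concatenates the two C_out // 2 halves along the channel axis (so C_out is assumed to be even). A minimal shape-contract sketch with illustrative sizes:

import torch

reduce_op = FactorizedReduce(C_in=16, C_out=32)
x = torch.randn(1, 16, 32, 32)
# H and W are halved; the two 16-channel halves concatenate to 32 channels.
print(reduce_op(x).shape)  # torch.Size([1, 32, 16, 16])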