This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Extract controller from mutator to make offline decisions #1758

Merged
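This PR separates architecture search from final training: the search script exports architecture decisions as JSON checkpoints, and the new retrain.py replays one of them offline through FixedArchitecture before training the fixed network. A minimal sketch of that hand-off, mirroring the calls that appear in the diffs below (the internal behaviour of FixedArchitecture is assumed, not shown here):

```python
from model import CNN  # examples/nas/darts/model.py below
from nni.nas.pytorch.fixed import FixedArchitecture

# Rebuild the search space, then pin every LayerChoice / InputChoice to the
# decisions recorded during search (same path as retrain.py uses below).
model = CNN(32, 3, 36, 10, 20, auxiliary=True)
architecture = FixedArchitecture(model, "./checkpoints/epoch_0.json")
model.cuda()  # retrain.py assumes CUDA is available
# ...then train the now-fixed network with an ordinary supervised loop.
```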
35 changes: 21 additions & 14 deletions examples/nas/darts/model.py
@@ -2,7 +2,7 @@
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables, darts
from nni.nas.pytorch import mutables


class AuxiliaryHead(nn.Module):
@@ -31,12 +31,14 @@ def forward(self, x):
return logits


class Node(darts.DartsNode):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.):
super().__init__(node_id, limitation=2)
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
@@ -48,18 +50,19 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, dr
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
],
key="{}_p{}".format(node_id, i)))
self.drop_path = ops.DropPath_(drop_path_prob)
key=choice_keys[-1]))
self.drop_path = ops.DropPath_()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
return sum(self.drop_path(o) for o in out if o is not None)
return self.input_switch(out)
QuanluZhang marked this conversation as resolved.
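To illustrate the key scheme set up above (an illustration only; the node name follows the naming used by Cell below and the channel count is an arbitrary example):

```python
# The first intermediate node of a normal cell has two predecessor nodes,
# so it registers two LayerChoice mutables and one InputChoice mutable:
node = Node("normal_n2", num_prev_nodes=2, channels=16, num_downsample_connect=0)
# LayerChoice keys: "normal_n2_p0", "normal_n2_p1"   (one per predecessor)
# InputChoice key:  "normal_n2_switch"               (n_chosen=2: keep two of the candidate inputs)
```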


class Cell(nn.Module):

def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
@@ -74,10 +77,9 @@ def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, redu

# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth),
depth + 2, channels, 2 if reduction else 0,
drop_path_prob=drop_path_prob))
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))

def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
@@ -93,7 +95,7 @@ def forward(self, s0, s1):
class CNN(nn.Module):

def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False, drop_path_prob=0.):
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
@@ -120,7 +122,7 @@ def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nod
c_cur *= 2
reduction = True

cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob)
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
@@ -147,3 +149,8 @@ def forward(self, x):
if aux_logits is not None:
return logits, aux_logits
return logits

def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p
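Since DropPath_ is now constructed without a probability, the retraining script below anneals it from outside via this new method; the schedule used there is linear in the epoch index:

```python
# Called once per epoch in retrain.py (below); shown here only as a usage reminder.
model.drop_path_prob(args.drop_path_prob * epoch / args.epochs)
```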
142 changes: 142 additions & 0 deletions examples/nas/darts/retrain.py
@@ -0,0 +1,142 @@
import logging
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import FixedArchitecture
from nni.nas.pytorch.utils import AverageMeter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def train(config, train_loader, model, archit, optimizer, criterion, epoch):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")

cur_step = epoch * len(train_loader)
cur_lr = optimizer.param_groups[0]['lr']
logger.info("Epoch %d LR %.6f", epoch, cur_lr)

model.train()

for step, (x, y) in enumerate(train_loader):
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
bs = x.size(0)

optimizer.zero_grad()
logits, aux_logits = model(x)
loss = criterion(logits, y)
if config.aux_weight > 0.:
loss += config.aux_weight * criterion(aux_logits, y)
loss.backward()
# gradient clipping
nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
optimizer.step()

accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), bs)
top1.update(accuracy["acc1"], bs)
top5.update(accuracy["acc5"], bs)

if step % config.log_frequency == 0 or step == len(train_loader) - 1:
logger.info(
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses,
top1=top1, top5=top5))

cur_step += 1

logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))


def validate(config, valid_loader, model, archit, criterion, epoch, cur_step):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")

model.eval()

with torch.no_grad():
for step, (X, y) in enumerate(valid_loader):
X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
N = X.size(0)

logits = model(X)
loss = criterion(logits, y)

accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), N)
top1.update(accuracy["acc1"], N)
top5.update(accuracy["acc5"], N)

if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
logger.info(
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
top1=top1, top5=top5))

logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))

return top1.avg


if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=20, type=int)
parser.add_argument("--batch-size", default=96, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=600, type=int)
parser.add_argument("--aux-weight", default=0.4, type=float)
parser.add_argument("--drop-path-prob", default=0.2, type=float)
parser.add_argument("--workers", default=4)
parser.add_argument("--grad-clip", default=5., type=float)

args = parser.parse_args()
assert torch.cuda.is_available()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)

model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
archit = FixedArchitecture(model, "./checkpoints/epoch_0.json")
Contributor: please remove the return value archit

Contributor Author: Architecture might need to be sent to GPU.
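One way this thread might be resolved (a sketch, not part of this diff; it assumes FixedArchitecture is an nn.Module like the framework's other mutators, which the diff itself does not confirm):

```python
# Hypothetical follow-up to the "move architecture to cuda" TODO below:
archit.cuda()  # only meaningful if FixedArchitecture registers parameters/buffers as module state
```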

criterion = nn.CrossEntropyLoss()
model.cuda()
criterion.cuda()
# TODO: move architecture to cuda

optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

train_loader = torch.utils.data.DataLoader(dataset_train,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(dataset_valid,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True)

best_top1 = 0.
for epoch in range(args.epochs):
drop_prob = args.drop_path_prob * epoch / args.epochs
model.drop_path_prob(drop_prob)

# training
train(args, train_loader, model, archit, optimizer, criterion, epoch)

# validation
cur_step = (epoch + 1) * len(train_loader)
top1 = validate(args, valid_loader, model, archit, criterion, epoch, cur_step)
best_top1 = max(best_top1, top1)

lr_scheduler.step()

logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
4 changes: 2 additions & 2 deletions examples/nas/darts/search.py
@@ -13,7 +13,7 @@
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=8, type=int)
parser.add_argument("--batch-size", default=96, type=int)
parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=50, type=int)
args = parser.parse_args()
Expand All @@ -36,4 +36,4 @@
batch_size=args.batch_size,
log_frequency=args.log_frequency,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
trainer.train_and_validate()
trainer.train()
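In short, search.py now calls trainer.train() (previously train_and_validate()), and the per-epoch JSON files written under ./checkpoints by the ArchitectureCheckpoint callback are what retrain.py above loads through FixedArchitecture.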
20 changes: 10 additions & 10 deletions examples/nas/enas/macro.py
@@ -6,7 +6,7 @@

class ENASLayer(mutables.MutableScope):

def __init__(self, key, num_prev_layers, in_filters, out_filters):
def __init__(self, key, prev_labels, in_filters, out_filters):
Contributor: please add docstring
super().__init__(key)
self.in_filters = in_filters
self.out_filters = out_filters
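A docstring along the lines requested in the review comment above might look like this; it is a sketch, with parameter descriptions inferred from how ENASLayer is used in this file rather than taken from NNI documentation:

```python
from nni.nas.pytorch import mutables


class ENASLayer(mutables.MutableScope):
    """One searchable layer of the ENAS macro search space.

    Parameters
    ----------
    key : str
        Unique label of this layer, also used as its MutableScope key.
    prev_labels : list of str
        Keys of all preceding layers; they form the candidate pool
        (``choose_from``) of the skip-connection InputChoice.
    in_filters : int
        Number of input channels.
    out_filters : int
        Number of output channels produced by every candidate branch.
    """
```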
@@ -18,16 +18,16 @@ def __init__(self, num_prev_layers, in_filters, out_filters):
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if num_prev_layers > 0:
self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
if len(prev_labels) > 0:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
Contributor: we should think about what if the inputs are not all from layerchoice but from normal layers, then how to specify choose_from

Contributor Author: Normal layers should be wrapped with a mutable scope to gain the power from the keys. Maybe we can do some annotations.

else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

def forward(self, prev_layers, prev_labels):
def forward(self, prev_layers):
out = self.mutable(prev_layers[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1], tags=prev_labels)
connection = self.skipconnect(prev_layers[:-1])
QuanluZhang marked this conversation as resolved.
if connection is not None:
out += connection
return self.batch_norm(out)
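To make the author's reply above concrete, here is a minimal sketch (an illustration only, not code from this PR) of wrapping an ordinary layer in a MutableScope so that its key can later be listed in choose_from:

```python
import torch.nn as nn

from nni.nas.pytorch import mutables


class PlainLayer(mutables.MutableScope):
    """An ordinary, non-searchable layer that still exposes a key."""

    def __init__(self, key, in_filters, out_filters):
        super().__init__(key)
        self.conv = nn.Conv2d(in_filters, out_filters, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


# A later layer could then include it among its skip-connection candidates, e.g.:
# mutables.InputChoice(choose_from=["plain_0", "layer_0"], n_chosen=None, reduction="sum")
```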
@@ -53,11 +53,12 @@ def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,

self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList()
labels = []
for layer_id in range(self.num_layers):
labels.append("layer_{}".format(layer_id))
if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id,
self.out_filters, self.out_filters))
self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))

self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
@@ -66,12 +67,11 @@ def forward(self, x):
bs = x.size(0)
cur = self.stem(x)

layers, labels = [cur], []
layers = [cur]

for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers, labels)
cur = self.layers[layer_id](layers)
layers.append(cur)
labels.append(self.layers[layer_id].key)
if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
33 changes: 17 additions & 16 deletions examples/nas/enas/micro.py
@@ -32,9 +32,9 @@ def forward(self, x):


class Cell(nn.Module):
def __init__(self, cell_name, num_prev_layers, channels):
def __init__(self, cell_name, prev_labels, channels):
super().__init__()
self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True,
self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = mutables.LayerChoice([
SepConvBN(channels, channels, 3, 1),
Expand All @@ -44,21 +44,21 @@ def __init__(self, cell_name, num_prev_layers, channels):
nn.Identity()
], key=cell_name + "_op")

def forward(self, prev_layers, prev_labels):
chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels)
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask


class Node(mutables.MutableScope):
def __init__(self, node_name, num_prev_layers, channels):
def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", num_prev_layers, channels)
self.cell_y = Cell(node_name + "_y", num_prev_layers, channels)
self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", prev_node_names, channels)

def forward(self, prev_layers, prev_labels):
out_x, mask_x = self.cell_x(prev_layers, prev_labels)
out_y, mask_y = self.cell_y(prev_layers, prev_labels)
def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y


@@ -93,8 +93,11 @@ def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduc

self.num_nodes = num_nodes
name_prefix = "reduce" if reduction else "normal"
self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i),
i + 2, out_channels) for i in range(num_nodes)])
self.nodes = nn.ModuleList()
node_labels = ["prev1", "prev2"]
for i in range(num_nodes):
node_labels.append("{}_node_{}".format(name_prefix, i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True)
self.bn = nn.BatchNorm2d(out_channels, affine=False)
self.reset_parameters()
@@ -106,14 +109,12 @@ def forward(self, pprev, prev):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)

prev_nodes_out = [pprev_, prev_]
prev_nodes_labels = ["prev1", "prev2"]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels)
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask
prev_nodes_out.append(node_out)
prev_nodes_labels.append(self.nodes[i].key)


unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]