diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 8705cba4d6..9ba06a7ca3 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1 +1,2 @@ -data +data +checkpoints diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py index 5c284b5a46..6a9afe6ff3 100644 --- a/examples/nas/darts/model.py +++ b/examples/nas/darts/model.py @@ -2,7 +2,7 @@ import torch.nn as nn import ops -from nni.nas.pytorch import mutables, darts +from nni.nas.pytorch import mutables class AuxiliaryHead(nn.Module): @@ -31,12 +31,14 @@ def forward(self, x): return logits -class Node(darts.DartsNode): - def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.): - super().__init__(node_id, limitation=2) +class Node(nn.Module): + def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): + super().__init__() self.ops = nn.ModuleList() + choice_keys = [] for i in range(num_prev_nodes): stride = 2 if i < num_downsample_connect else 1 + choice_keys.append("{}_p{}".format(node_id, i)) self.ops.append( mutables.LayerChoice( [ @@ -48,18 +50,19 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, dr ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), ], - key="{}_p{}".format(node_id, i))) - self.drop_path = ops.DropPath_(drop_path_prob) + key=choice_keys[-1])) + self.drop_path = ops.DropPath_() + self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) def forward(self, prev_nodes): assert len(self.ops) == len(prev_nodes) out = [op(node) for op, node in zip(self.ops, prev_nodes)] - return sum(self.drop_path(o) for o in out if o is not None) + return self.input_switch(out) class Cell(nn.Module): - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.): + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): super().__init__() self.reduction = reduction self.n_nodes = n_nodes @@ -74,10 +77,9 @@ def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, redu # generate dag self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth), - depth + 2, channels, 2 if reduction else 0, - drop_path_prob=drop_path_prob)) + for depth in range(2, self.n_nodes + 2): + self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), + depth, channels, 2 if reduction else 0)) def forward(self, s0, s1): # s0, s1 are the outputs of previous previous cell and previous cell, respectively. 
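Note on the pattern above: each candidate edge into a node is now a LayerChoice registered under an explicit key ("<node_id>_p<i>"), and the node's InputChoice refers back to those same keys to keep 2 of the incoming edges, replacing the old DartsNode/limitation mechanism. A minimal self-contained sketch of that key-sharing pattern (illustrative only — ToyNode and its "toy_*" keys are invented for this note, not part of the diff):

import torch.nn as nn
from nni.nas.pytorch import mutables

class ToyNode(nn.Module):
    def __init__(self, channels):
        super().__init__()
        toy_keys = ["toy_p0", "toy_p1", "toy_p2"]
        # one LayerChoice per incoming edge, each registered under its own key
        self.edges = nn.ModuleList([
            mutables.LayerChoice([nn.Conv2d(channels, channels, 3, padding=1),
                                  nn.MaxPool2d(3, stride=1, padding=1)], key=k)
            for k in toy_keys
        ])
        # the InputChoice keeps 2 of the 3 edges, referencing them by the keys above
        self.switch = mutables.InputChoice(choose_from=toy_keys, n_chosen=2, key="toy_switch")

    def forward(self, inputs):  # inputs: one tensor per incoming edge
        outs = [edge(x) for edge, x in zip(self.edges, inputs)]
        return self.switch(outs)  # reduced according to the mutator's decision for "toy_switch"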
@@ -93,7 +95,7 @@ def forward(self, s0, s1): class CNN(nn.Module): def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, - stem_multiplier=3, auxiliary=False, drop_path_prob=0.): + stem_multiplier=3, auxiliary=False): super().__init__() self.in_channels = in_channels self.channels = channels @@ -120,7 +122,7 @@ def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nod c_cur *= 2 reduction = True - cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob) + cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) self.cells.append(cell) c_cur_out = c_cur * n_nodes channels_pp, channels_p = channels_p, c_cur_out @@ -147,3 +149,8 @@ def forward(self, x): if aux_logits is not None: return logits, aux_logits return logits + + def drop_path_prob(self, p): + for module in self.modules(): + if isinstance(module, ops.DropPath_): + module.p = p diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py new file mode 100644 index 0000000000..5c8fabf8d0 --- /dev/null +++ b/examples/nas/darts/retrain.py @@ -0,0 +1,143 @@ +import logging +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +import utils +from model import CNN +from nni.nas.pytorch.fixed import apply_fixed_architecture +from nni.nas.pytorch.utils import AverageMeter + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def train(config, train_loader, model, optimizer, criterion, epoch): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + cur_step = epoch * len(train_loader) + cur_lr = optimizer.param_groups[0]['lr'] + logger.info("Epoch %d LR %.6f", epoch, cur_lr) + + model.train() + + for step, (x, y) in enumerate(train_loader): + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + bs = x.size(0) + + optimizer.zero_grad() + logits, aux_logits = model(x) + loss = criterion(logits, y) + if config.aux_weight > 0.: + loss += config.aux_weight * criterion(aux_logits, y) + loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), bs) + top1.update(accuracy["acc1"], bs) + top5.update(accuracy["acc5"], bs) + + if step % config.log_frequency == 0 or step == len(train_loader) - 1: + logger.info( + "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + cur_step += 1 + + logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + +def validate(config, valid_loader, model, criterion, epoch, cur_step): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + model.eval() + + with torch.no_grad(): + for step, (X, y) in enumerate(valid_loader): + X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) + N = X.size(0) + + logits = model(X) + loss = criterion(logits, y) + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), N) + top1.update(accuracy["acc1"], N) + top5.update(accuracy["acc5"], N) + + if step % config.log_frequency == 0 or step == len(valid_loader) - 1: + 
logger.info( + "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + return top1.avg + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=20, type=int) + parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--epochs", default=600, type=int) + parser.add_argument("--aux-weight", default=0.4, type=float) + parser.add_argument("--drop-path-prob", default=0.2, type=float) + parser.add_argument("--workers", default=4) + parser.add_argument("--grad-clip", default=5., type=float) + parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json") + + args = parser.parse_args() + dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16) + + model = CNN(32, 3, 36, 10, args.layers, auxiliary=True) + apply_fixed_architecture(model, args.arc_checkpoint, device=device) + criterion = nn.CrossEntropyLoss() + + model.to(device) + criterion.to(device) + + optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) + + train_loader = torch.utils.data.DataLoader(dataset_train, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True) + valid_loader = torch.utils.data.DataLoader(dataset_valid, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True) + + best_top1 = 0. 
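+    # drop-path probability is annealed linearly from 0 up to --drop-path-prob over retraining;
+    # with the defaults (0.2 over 600 epochs), epoch 300 uses p = 0.2 * 300 / 600 = 0.1, and
+    # CNN.drop_path_prob() pushes that value into every ops.DropPath_ module of the fixed architecture.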
+ for epoch in range(args.epochs): + drop_prob = args.drop_path_prob * epoch / args.epochs + model.drop_path_prob(drop_prob) + + # training + train(args, train_loader, model, optimizer, criterion, epoch) + + # validation + cur_step = (epoch + 1) * len(train_loader) + top1 = validate(args, valid_loader, model, criterion, epoch, cur_step) + best_top1 = max(best_top1, top1) + + lr_scheduler.step() + + logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index 75773cf5e0..02c720a60c 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -13,7 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser("darts") parser.add_argument("--layers", default=8, type=int) - parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--batch-size", default=64, type=int) parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--epochs", default=50, type=int) args = parser.parse_args() @@ -36,4 +36,4 @@ batch_size=args.batch_size, log_frequency=args.log_frequency, callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) - trainer.train_and_validate() + trainer.train() diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py index 48fcaaf03d..a9309f9079 100644 --- a/examples/nas/enas/macro.py +++ b/examples/nas/enas/macro.py @@ -6,7 +6,7 @@ class ENASLayer(mutables.MutableScope): - def __init__(self, key, num_prev_layers, in_filters, out_filters): + def __init__(self, key, prev_labels, in_filters, out_filters): super().__init__(key) self.in_filters = in_filters self.out_filters = out_filters @@ -18,16 +18,16 @@ def __init__(self, key, num_prev_layers, in_filters, out_filters): PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('max', in_filters, out_filters, 3, 1, 1) ]) - if num_prev_layers > 0: - self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum") + if len(prev_labels) > 0: + self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum") else: self.skipconnect = None self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) - def forward(self, prev_layers, prev_labels): + def forward(self, prev_layers): out = self.mutable(prev_layers[-1]) if self.skipconnect is not None: - connection = self.skipconnect(prev_layers[:-1], tags=prev_labels) + connection = self.skipconnect(prev_layers[:-1]) if connection is not None: out += connection return self.batch_norm(out) @@ -53,11 +53,12 @@ def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, self.layers = nn.ModuleList() self.pool_layers = nn.ModuleList() + labels = [] for layer_id in range(self.num_layers): + labels.append("layer_{}".format(layer_id)) if layer_id in self.pool_layers_idx: self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) - self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id, - self.out_filters, self.out_filters)) + self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters)) self.gap = nn.AdaptiveAvgPool2d(1) self.dense = nn.Linear(self.out_filters, self.num_classes) @@ -66,12 +67,11 @@ def forward(self, x): bs = x.size(0) cur = self.stem(x) - layers, labels = [cur], [] + layers = [cur] for layer_id in range(self.num_layers): - cur = self.layers[layer_id](layers, labels) + cur = self.layers[layer_id](layers) layers.append(cur) - labels.append(self.layers[layer_id].key) if 
layer_id in self.pool_layers_idx: for i, layer in enumerate(layers): layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) diff --git a/examples/nas/enas/micro.py b/examples/nas/enas/micro.py index 209abf2405..fabd3919ca 100644 --- a/examples/nas/enas/micro.py +++ b/examples/nas/enas/micro.py @@ -32,9 +32,9 @@ def forward(self, x): class Cell(nn.Module): - def __init__(self, cell_name, num_prev_layers, channels): + def __init__(self, cell_name, prev_labels, channels): super().__init__() - self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True, + self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True, key=cell_name + "_input") self.op_choice = mutables.LayerChoice([ SepConvBN(channels, channels, 3, 1), @@ -44,21 +44,21 @@ def __init__(self, cell_name, num_prev_layers, channels): nn.Identity() ], key=cell_name + "_op") - def forward(self, prev_layers, prev_labels): - chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels) + def forward(self, prev_layers): + chosen_input, chosen_mask = self.input_choice(prev_layers) cell_out = self.op_choice(chosen_input) return cell_out, chosen_mask class Node(mutables.MutableScope): - def __init__(self, node_name, num_prev_layers, channels): + def __init__(self, node_name, prev_node_names, channels): super().__init__(node_name) - self.cell_x = Cell(node_name + "_x", num_prev_layers, channels) - self.cell_y = Cell(node_name + "_y", num_prev_layers, channels) + self.cell_x = Cell(node_name + "_x", prev_node_names, channels) + self.cell_y = Cell(node_name + "_y", prev_node_names, channels) - def forward(self, prev_layers, prev_labels): - out_x, mask_x = self.cell_x(prev_layers, prev_labels) - out_y, mask_y = self.cell_y(prev_layers, prev_labels) + def forward(self, prev_layers): + out_x, mask_x = self.cell_x(prev_layers) + out_y, mask_y = self.cell_y(prev_layers) return out_x + out_y, mask_x | mask_y @@ -93,8 +93,11 @@ def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduc self.num_nodes = num_nodes name_prefix = "reduce" if reduction else "normal" - self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i), - i + 2, out_channels) for i in range(num_nodes)]) + self.nodes = nn.ModuleList() + node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] + for i in range(num_nodes): + node_labels.append("{}_node_{}".format(name_prefix, i)) + self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels)) self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) self.bn = nn.BatchNorm2d(out_channels, affine=False) self.reset_parameters() @@ -106,14 +109,12 @@ def forward(self, pprev, prev): pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) prev_nodes_out = [pprev_, prev_] - prev_nodes_labels = ["prev1", "prev2"] nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) for i in range(self.num_nodes): - node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels) + node_out, mask = self.nodes[i](prev_nodes_out) nodes_used_mask[:mask.size(0)] |= mask prev_nodes_out.append(node_out) - prev_nodes_labels.append(self.nodes[i].key) - + unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) unused_nodes = F.relu(unused_nodes) conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] diff --git a/examples/nas/enas/search.py 
b/examples/nas/enas/search.py index 6e1bdec34c..35bc930333 100644 --- a/examples/nas/enas/search.py +++ b/examples/nas/enas/search.py @@ -13,7 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) - parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") args = parser.parse_args() @@ -43,5 +43,6 @@ num_epochs=num_epochs, dataset_train=dataset_train, dataset_valid=dataset_valid, - log_frequency=args.log_frequency) - trainer.train_and_validate() + log_frequency=args.log_frequency, + mutator=mutator) + trainer.train() diff --git a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py index dd2b844d24..550e449dfc 100644 --- a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py @@ -1,70 +1,127 @@ import logging import torch.nn as nn - -from nni.nas.pytorch.mutables import Mutable +from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice +from nni.nas.pytorch.utils import StructuredMutableTreeNode logger = logging.getLogger(__name__) class BaseMutator(nn.Module): + """ + A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing + callbacks that are called in ``forward`` in Mutables. + """ + def __init__(self, model): super().__init__() self.__dict__["model"] = model - self.before_parse_search_space() - self._parse_search_space() - self.after_parse_search_space() - - def before_parse_search_space(self): - pass - - def after_parse_search_space(self): - pass - - def _parse_search_space(self): - for name, mutable, _ in self.named_mutables(distinct=False): - mutable.name = name - mutable.set_mutator(self) + self._structured_mutables = self._parse_search_space(self.model) - def named_mutables(self, root=None, distinct=True): + def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None): + if memo is None: + memo = set() if root is None: - root = self.model - # if distinct is true, the method will filter out those with duplicated keys - key2module = dict() - for name, module in root.named_modules(): + root = StructuredMutableTreeNode(None) + if module not in memo: + memo.add(module) if isinstance(module, Mutable): - module_distinct = False - if module.key in key2module: - assert key2module[module.key].similar(module), \ - "Mutable \"{}\" that share the same key must be similar to each other".format(module.key) - else: - module_distinct = True - key2module[module.key] = module - if distinct: - if module_distinct: - yield name, module - else: - yield name, module, module_distinct - - def __setattr__(self, key, value): - if key in ["model", "net", "network"]: - logger.warning("Think twice if you are including the network into mutator.") - return super().__setattr__(key, value) - + if nested_detection is not None: + raise RuntimeError("Cannot have nested search space. 
Error at {} in {}" + .format(module, nested_detection)) + module.name = prefix + module.set_mutator(self) + root = root.add_child(module) + if not isinstance(module, MutableScope): + nested_detection = module + if isinstance(module, InputChoice): + for k in module.choose_from: + if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]: + raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY." + .format(k, module.key)) + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + self._parse_search_space(submodule, root, submodule_prefix, memo=memo, + nested_detection=nested_detection) + return root + + @property + def mutables(self): + return self._structured_mutables + + @property def forward(self, *inputs): - raise NotImplementedError("Mutator is not forward-able") + raise RuntimeError("Forward is undefined for mutators.") def enter_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is entered. + + Parameters + ---------- + mutable_scope: MutableScope + + Returns + ------- + None + """ pass def exit_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is exited. + + Parameters + ---------- + mutable_scope: MutableScope + + Returns + ------- + None + """ pass def on_forward_layer_choice(self, mutable, *inputs): + """ + Callbacks of forward in LayerChoice. + + Parameters + ---------- + mutable: LayerChoice + inputs: list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ raise NotImplementedError - def on_forward_input_choice(self, mutable, tensor_list, tags): + def on_forward_input_choice(self, mutable, tensor_list): + """ + Callbacks of forward in InputChoice. + + Parameters + ---------- + mutable: InputChoice + tensor_list: list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ raise NotImplementedError def export(self): + """ + Export the data of all decisions. This should output the decisions of all the mutables, so that the whole + network can be fully determined with these decisions for further training from scratch. 
+ + Returns + ------- + dict + """ raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py index 1248cc09e2..db1b033073 100644 --- a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py @@ -12,5 +12,9 @@ def validate(self): raise NotImplementedError @abstractmethod - def train_and_validate(self): + def export(self, file): + raise NotImplementedError + + @abstractmethod + def checkpoint(self): raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/callbacks.py b/src/sdk/pynni/nni/nas/pytorch/callbacks.py index 2a76b3dab8..83ae62cde0 100644 --- a/src/sdk/pynni/nni/nas/pytorch/callbacks.py +++ b/src/sdk/pynni/nni/nas/pytorch/callbacks.py @@ -1,9 +1,6 @@ -import json import logging import os -import torch - _logger = logging.getLogger(__name__) @@ -44,26 +41,11 @@ def on_epoch_end(self, epoch): class ArchitectureCheckpoint(Callback): - class TorchTensorEncoder(json.JSONEncoder): - def default(self, o): # pylint: disable=method-hidden - if isinstance(o, torch.Tensor): - olist = o.tolist() - if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)): - _logger.warning("Every element in %s is either 0 or 1. " - "You might consider convert it into bool.", olist) - return olist - return super().default(o) - def __init__(self, checkpoint_dir, every="epoch"): super().__init__() assert every == "epoch" self.checkpoint_dir = checkpoint_dir os.makedirs(self.checkpoint_dir, exist_ok=True) - def _export_to_file(self, file): - mutator_export = self.mutator.export() - with open(file, "w") as f: - json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder) - def on_epoch_end(self, epoch): - self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) + self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py index 7f2c9f9675..3bf08d285c 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -1,3 +1,2 @@ from .mutator import DartsMutator -from .trainer import DartsTrainer -from .scope import DartsNode \ No newline at end of file +from .trainer import DartsTrainer \ No newline at end of file diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py index 589847d2b6..91d739c0a3 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -2,35 +2,47 @@ from torch import nn as nn from torch.nn import functional as F -from nni.nas.pytorch.mutables import LayerChoice from nni.nas.pytorch.mutator import Mutator -from .scope import DartsNode +from nni.nas.pytorch.mutables import LayerChoice, InputChoice class DartsMutator(Mutator): - - def after_parse_search_space(self): + def __init__(self, model): + super().__init__(model) self.choices = nn.ParameterDict() - for _, mutable in self.named_mutables(): + for mutable in self.mutables: if isinstance(mutable, LayerChoice): - self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1)) + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) - def on_calc_layer_choice_mask(self, mutable: LayerChoice): - return F.softmax(self.choices[mutable.key], dim=-1)[:-1] + def device(self): + for v in self.choices.values(): + 
return v.device - def export(self): - result = super().export() - for _, darts_node in self.named_mutables(): - if isinstance(darts_node, DartsNode): - keys, edges_max = [], [] # key of all the layer choices in current node, and their best edge weight - for _, choice in self.named_mutables(darts_node): - if isinstance(choice, LayerChoice): - keys.append(choice.key) - max_val, index = torch.max(result[choice.key], 0) - edges_max.append(max_val) - result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool() - _, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation) # pylint: disable=not-callable - for i, key in enumerate(keys): - if i not in topk_edge_indices: - result[key] = torch.zeros_like(result[key]) + def sample_search(self): + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] + elif isinstance(mutable, InputChoice): + result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) + return result + + def sample_final(self): + result = dict() + edges_max = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0) + edges_max[mutable.key] = max_val + result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool() + for mutable in self.mutables: + if isinstance(mutable, InputChoice): + weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable + _, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates) + selected_multihot = [] + for i, src_key in enumerate(mutable.choose_from): + if i not in topk_edge_indices and src_key in result: + result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph + selected_multihot.append(i in topk_edge_indices) + result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable return result diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/scope.py b/src/sdk/pynni/nni/nas/pytorch/darts/scope.py deleted file mode 100644 index a2bf2b3cff..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/scope.py +++ /dev/null @@ -1,11 +0,0 @@ -from nni.nas.pytorch.mutables import MutableScope - - -class DartsNode(MutableScope): - """ - At most `limitation` choice is activated in a `DartsNode` when exporting. 
- """ - - def __init__(self, key, limitation): - super().__init__(key) - self.limitation = limitation diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 464832eadf..c6b29de04a 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -4,7 +4,7 @@ from torch import nn as nn from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup +from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import DartsMutator @@ -13,9 +13,9 @@ def __init__(self, model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None): - super().__init__(model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator if mutator is not None else DartsMutator(model), callbacks) + super().__init__(model, mutator if mutator is not None else DartsMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), weight_decay=1.0E-3) n_train = len(self.dataset_train) @@ -31,6 +31,9 @@ def __init__(self, model, loss, metrics, batch_size=batch_size, sampler=valid_sampler, num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) def train_one_epoch(self, epoch): self.model.train() @@ -47,8 +50,8 @@ def train_one_epoch(self, epoch): # phase 1. child network step self.optimizer.zero_grad() - with self.mutator.forward_pass(): - logits = self.model(trn_X) + self.mutator.reset() + logits = self.model(trn_X) loss = self.loss(logits, trn_y) loss.backward() # gradient clipping @@ -76,10 +79,10 @@ def validate_one_epoch(self, epoch): self.mutator.eval() meters = AverageMeterGroup() with torch.no_grad(): - for step, (X, y) in enumerate(self.valid_loader): + self.mutator.reset() + for step, (X, y) in enumerate(self.test_loader): X, y = X.to(self.device), y.to(self.device) - with self.mutator.forward_pass(): - logits = self.model(X) + logits = self.model(X) metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: @@ -93,8 +96,8 @@ def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): v_model: backup model before this step lr: learning rate for virtual gradient step (same as net lr) """ - with self.mutator.forward_pass(): - loss = self.loss(self.model(val_X), val_y) + self.mutator.reset() + loss = self.loss(self.model(val_X), val_y) w_model = tuple(self.model.parameters()) w_ctrl = tuple(self.mutator.parameters()) w_grads = torch.autograd.grad(loss, w_model + w_ctrl) @@ -125,8 +128,8 @@ def _compute_hessian(self, model, dw, trn_X, trn_y): for p, d in zip(self.model.parameters(), dw): p += eps * d - with self.mutator.forward_pass(): - loss = self.loss(self.model(trn_X), trn_y) + self.mutator.reset() + loss = self.loss(self.model(trn_X), trn_y) if e > 0: dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } elif e < 0: diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 3bd32459b4..9d9a176352 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ 
b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from nni.nas.pytorch.mutables import LayerChoice from nni.nas.pytorch.mutator import Mutator +from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope class StackedLSTMCell(nn.Module): @@ -27,15 +27,14 @@ def forward(self, inputs, hidden): class EnasMutator(Mutator): def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, skip_target=0.4, branch_bias=0.25): + super().__init__(model) self.lstm_size = lstm_size self.lstm_num_layers = lstm_num_layers self.tanh_constant = tanh_constant self.cell_exit_extra_step = cell_exit_extra_step self.skip_target = skip_target self.branch_bias = branch_bias - super().__init__(model) - def before_parse_search_space(self): self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) @@ -45,9 +44,8 @@ def before_parse_search_space(self): self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") self.bias_dict = nn.ParameterDict() - def after_parse_search_space(self): self.max_layer_choice = 0 - for _, mutable in self.named_mutables(): + for mutable in self.mutables: if isinstance(mutable, LayerChoice): if self.max_layer_choice == 0: self.max_layer_choice = mutable.length @@ -64,8 +62,29 @@ def is_conv(choice): self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) - def before_pass(self): - super().before_pass() + def sample_search(self): + self._initialize() + self._sample(self.mutables) + return self._choices + + def sample_final(self): + return self.sample_search() + + def _sample(self, tree): + mutable = tree.mutable + if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_layer_choice(mutable) + elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_input_choice(mutable) + for child in tree.children: + self._sample(child) + if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: + if self.cell_exit_extra_step: + self._lstm_next_step() + self._mark_anchor(mutable.key) + + def _initialize(self): + self._choices = dict() self._anchors_hid = dict() self._inputs = self.g_emb.data self._c = [torch.zeros((1, self.lstm_size), @@ -84,7 +103,7 @@ def _lstm_next_step(self): def _mark_anchor(self, key): self._anchors_hid[key] = self._h[-1] - def on_calc_layer_choice_mask(self, mutable): + def _sample_layer_choice(self, mutable): self._lstm_next_step() logit = self.soft(self._h[-1]) if self.tanh_constant is not None: @@ -94,14 +113,14 @@ def on_calc_layer_choice_mask(self, mutable): branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, branch_id) self.sample_log_prob += torch.sum(log_prob) - entropy = (log_prob * torch.exp(-log_prob)).detach() + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type self.sample_entropy += torch.sum(entropy) self._inputs = self.embedding(branch_id) return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) - def on_calc_input_choice_mask(self, mutable, tags): + def _sample_input_choice(self, mutable): query, anchors = [], [] - for label 
in tags: + for label in mutable.choose_from: if label not in self._anchors_hid: self._lstm_next_step() self._mark_anchor(label) # empty loop, fill not found @@ -113,8 +132,8 @@ def on_calc_input_choice_mask(self, mutable, tags): if self.tanh_constant is not None: query = self.tanh_constant * torch.tanh(query) - if mutable.n_selected is None: - logit = torch.cat([-query, query], 1) + if mutable.n_chosen is None: + logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) skip_prob = torch.sigmoid(logit) @@ -123,19 +142,14 @@ def on_calc_input_choice_mask(self, mutable, tags): log_prob = self.cross_entropy_loss(logit, skip) self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) else: - assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS." + assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." logit = query.view(1, -1) index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) - skip = F.one_hot(index).view(-1) + skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1) log_prob = self.cross_entropy_loss(logit, index) self._inputs = anchors[index.item()] self.sample_log_prob += torch.sum(log_prob) - entropy = (log_prob * torch.exp(-log_prob)).detach() + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type self.sample_entropy += torch.sum(entropy) return skip.bool() - - def exit_mutable_scope(self, mutable_scope): - if self.cell_exit_extra_step: - self._lstm_next_step() - self._mark_anchor(mutable_scope.key) diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 7d3e493782..1ed302ac7b 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -2,7 +2,7 @@ import torch.optim as optim from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup +from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import EnasMutator @@ -12,9 +12,9 @@ def __init__(self, model, loss, metrics, reward_function, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4): - super().__init__(model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator if mutator is not None else EnasMutator(model), callbacks) + super().__init__(model, mutator if mutator is not None else EnasMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) self.reward_function = reward_function self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) @@ -52,8 +52,9 @@ def train_one_epoch(self, epoch): x, y = x.to(self.device), y.to(self.device) self.optimizer.zero_grad() - with self.mutator.forward_pass(): - logits = self.model(x) + with torch.no_grad(): + self.mutator.reset() + logits = self.model(x) if isinstance(logits, tuple): logits, aux_logits = logits @@ -81,7 +82,8 @@ def train_one_epoch(self, epoch): for step, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) - with self.mutator.forward_pass(): + self.mutator.reset() + with torch.no_grad(): logits = 
self.model(x) metrics = self.metrics(logits, y) reward = self.reward_function(logits, y) @@ -107,9 +109,9 @@ def train_one_epoch(self, epoch): self.mutator_optim.zero_grad() if self.log_frequency is not None and step % self.log_frequency == 0: - print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - mutator_step // self.mutator_steps_aggregate, - self.mutator_steps, meters)) + print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, + mutator_step // self.mutator_steps_aggregate, + self.mutator_steps, meters)) mutator_step += 1 if mutator_step >= total_mutator_steps: break diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py index 526d66b610..6b83aa0800 100644 --- a/src/sdk/pynni/nni/nas/pytorch/fixed.py +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -2,10 +2,12 @@ import torch +from nni.nas.pytorch.mutables import MutableScope from nni.nas.pytorch.mutator import Mutator class FixedArchitecture(Mutator): + def __init__(self, model, fixed_arc, strict=True): """ Initialize a fixed architecture mutator. @@ -20,39 +22,57 @@ def __init__(self, model, fixed_arc, strict=True): Force everything that appears in `fixed_arc` to be used at least once. """ super().__init__(model) - if isinstance(fixed_arc, str): - with open(fixed_arc, "r") as f: - fixed_arc = json.load(f.read()) self._fixed_arc = fixed_arc - self._strict = strict - - def _encode_tensor(self, data): - if isinstance(data, list): - if all(map(lambda o: isinstance(o, bool), data)): - return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable - else: - return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable - if isinstance(data, dict): - return {k: self._encode_tensor(v) for k, v in data.items()} - return data - - def before_pass(self): - self._unused_key = set(self._fixed_arc.keys()) - - def after_pass(self): - if self._strict: - if self._unused_key: - raise ValueError("{} are never used by the network. 
" - "Set strict=False if you want to disable this check.".format(self._unused_key)) - - def _check_key(self, key): - if key not in self._fixed_arc: - raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key)) - - def on_calc_layer_choice_mask(self, mutable): - self._check_key(mutable.key) - return self._fixed_arc[mutable.key] - - def on_calc_input_choice_mask(self, mutable, tags): - self._check_key(mutable.key) - return self._fixed_arc[mutable.key] + + mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) + fixed_arc_keys = set(self._fixed_arc.keys()) + if fixed_arc_keys - mutable_keys: + raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) + if mutable_keys - fixed_arc_keys: + raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) + + def sample_search(self): + return self._fixed_arc + + def sample_final(self): + return self._fixed_arc + + +def _encode_tensor(data, device): + if isinstance(data, list): + if all(map(lambda o: isinstance(o, bool), data)): + return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable + else: + return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable + if isinstance(data, dict): + return {k: _encode_tensor(v, device) for k, v in data.items()} + return data + + +def apply_fixed_architecture(model, fixed_arc_path, device=None): + """ + Load architecture from `fixed_arc_path` and apply to model. + + Parameters + ---------- + model: torch.nn.Module + Model with mutables. + fixed_arc_path: str + Path to the JSON that stores the architecture. + device: torch.device + Architecture weights will be transfered to `device`. + + Returns + ------- + FixedArchitecture + """ + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(fixed_arc_path, str): + with open(fixed_arc_path, "r") as f: + fixed_arc = json.load(f) + fixed_arc = _encode_tensor(fixed_arc, device) + architecture = FixedArchitecture(model, fixed_arc) + architecture.to(device) + architecture.reset() diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 16b73b903d..79cde1cf3f 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -1,6 +1,6 @@ import torch.nn as nn -from nni.nas.utils import global_mutable_counting +from nni.nas.pytorch.utils import global_mutable_counting class Mutable(nn.Module): @@ -37,7 +37,7 @@ def set_mutator(self, mutator): self.__dict__["mutator"] = mutator def forward(self, *inputs): - raise NotImplementedError("Mutable forward must be implemented.") + raise NotImplementedError @property def key(self): @@ -51,9 +51,6 @@ def name(self): def name(self, name): self._name = name - def similar(self, other): - return type(self) == type(other) - def _check_built(self): if not hasattr(self, "mutator"): raise ValueError( @@ -66,19 +63,17 @@ def __repr__(self): class MutableScope(Mutable): """ - Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope - is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch - corresponding events, and do status dump or update. + Mutable scope labels a subgraph/submodule to help mutators make better decisions. 
+ Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope`` + and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update. """ def __init__(self, key): super().__init__(key=key) - def build(self): - self.mutator.on_init_mutable_scope(self) - def __call__(self, *args, **kwargs): try: + self._check_built() self.mutator.enter_mutable_scope(self) return super().__call__(*args, **kwargs) finally: @@ -93,43 +88,92 @@ def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None) self.reduction = reduction self.return_mask = return_mask - def __len__(self): - return len(self.choices) - def forward(self, *inputs): out, mask = self.mutator.on_forward_layer_choice(self, *inputs) if self.return_mask: return out, mask return out - def similar(self, other): - return type(self) == type(other) and self.length == other.length - class InputChoice(Mutable): - def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None): + """ + Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners, + use `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to + know about `choose_from`. + + The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones. + The keys are designed to be the keys of the sources. To help mutators make better decisions, + mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the + output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g., + ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a + module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed. + """ + + NO_KEY = "" + + def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, + reduction="mean", return_mask=False, key=None): + """ + Initialization. + + Parameters + ---------- + n_candidates: int + Number of inputs to choose from. + choose_from: list of str + List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled. + If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates` + number of empty string. + n_chosen: int + Recommended inputs to choose. If None, mutator is instructed to select any. + reduction: str + `mean`, `concat`, `sum` or `none`. + return_mask: bool + If `return_mask`, return output tensor and a mask. Otherwise return tensor only. + key: str + Key of the input choice. + """ super().__init__(key=key) + # precondition check + assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \ + "must be not None." + if choose_from is not None and n_candidates is None: + n_candidates = len(choose_from) + elif choose_from is None and n_candidates is not None: + choose_from = [self.NO_KEY] * n_candidates + assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." assert n_candidates > 0, "Number of candidates must be greater than 0." + assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ + "than number of candidates." 
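+        # after the checks above, choose_from and n_candidates are guaranteed consistent:
+        # len(choose_from) == n_candidates, with NO_KEY placeholders filled in when only a count was given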
+ self.n_candidates = n_candidates - self.n_selected = n_selected + self.choose_from = choose_from + self.n_chosen = n_chosen self.reduction = reduction self.return_mask = return_mask - def build(self): - self.mutator.on_init_input_choice(self) - - def forward(self, optional_inputs, tags=None): + def forward(self, optional_inputs): + """ + Forward method of LayerChoice. + + Parameters + ---------- + optional_inputs: list or dict + Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of + `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as + `choose_from`. + + Returns + ------- + tuple of torch.Tensor and torch.Tensor or torch.Tensor + """ + optional_input_list = optional_inputs + if isinstance(optional_inputs, dict): + optional_input_list = [optional_inputs[tag] for tag in self.choose_from] + assert isinstance(optional_input_list, list), "Optional input list must be a list" assert len(optional_inputs) == self.n_candidates, \ "Length of the input list must be equal to number of candidates." - if tags is None: - tags = [""] * self.n_candidates - else: - assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates." - out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags) + out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) if self.return_mask: return out, mask return out - - def similar(self, other): - return type(self) == type(other) and \ - self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 21d39545e7..80608c6925 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -1,46 +1,59 @@ -from contextlib import contextmanager - import torch -import torch.nn as nn from nni.nas.pytorch.base_mutator import BaseMutator -class Mutator(BaseMutator, nn.Module): +class Mutator(BaseMutator): - def export(self): - if self._in_forward_pass: - raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.") - if not self._cache: - raise RuntimeError("No running history found. You need to call your model at least once before exporting. " - "You might also want to check if there are no valid mutables in your model.") - return self._cache - - @contextmanager - def forward_pass(self): - self._in_forward_pass = True + def __init__(self, model): + super().__init__(model) self._cache = dict() - self.before_pass() - try: - yield self - finally: - self.after_pass() - self._in_forward_pass = False - def before_pass(self): - pass + def sample_search(self): + """ + Override to implement this method to iterate over mutables and make decisions. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def sample_final(self): + """ + Override to implement this method to iterate over mutables and make decisions that is final + for export and retraining. - def after_pass(self): - pass + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError - def _check_in_forward_pass(self): - if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass: - raise ValueError("Not in forward pass. 
Did you forget to call mutator.forward_pass(), or forget to call " - "super().before_pass() and after_pass() in your override method?") + def reset(self): + """ + Reset the mutator by call the `sample_search` to resample (for search). + + Returns + ------- + None + """ + self._cache = self.sample_search() + + def export(self): + """ + Resample (for final) and return results. + + Returns + ------- + dict + """ + return self.sample_final() def on_forward_layer_choice(self, mutable, *inputs): """ - Callback of layer choice forward. Override if you are an advanced user. On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. @@ -54,18 +67,17 @@ def on_forward_layer_choice(self, mutable, *inputs): ------- tuple of torch.Tensor and torch.Tensor """ - self._check_in_forward_pass() def _map_fn(op, *inputs): return op(*inputs) - mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable)) + mask = self._get_decision(mutable) + assert len(mask) == len(mutable.choices) out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) return self._tensor_reduction(mutable.reduction, out), mask - def on_forward_input_choice(self, mutable, tensor_list, tags): + def on_forward_input_choice(self, mutable, tensor_list): """ - Callback of input choice forward. Override if you are an advanced user. On default, this method calls :meth:`on_calc_input_choice_mask` with `tags` to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the @@ -81,48 +93,11 @@ def on_forward_input_choice(self, mutable, tensor_list, tags): ------- tuple of torch.Tensor and torch.Tensor """ - self._check_in_forward_pass() - mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags)) + mask = self._get_decision(mutable) + assert len(mask) == mutable.n_candidates out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) return self._tensor_reduction(mutable.reduction, out), mask - def on_calc_layer_choice_mask(self, mutable): - """ - Recommended to override. Calculate a mask tensor for a layer choice. - - Parameters - ---------- - mutable: LayerChoice - Corresponding layer choice object. - - Returns - ------- - torch.Tensor - Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, - the numbers are treated as switch. - """ - raise NotImplementedError("Layer choice mask calculation must be implemented") - - def on_calc_input_choice_mask(self, mutable, tags): - """ - Recommended to override. Calculate a mask tensor for a input choice. - - Parameters - ---------- - mutable: InputChoice - Corresponding input choice object. - tags: list of string - The name of labels of input tensors given by user. Usually it's a - :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user. - - Returns - ------- - torch.Tensor - Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, - the numbers are treated as switch. 
- """ - raise NotImplementedError("Input choice mask calculation must be implemented") - def _select_with_mask(self, map_fn, candidates, mask): if "BoolTensor" in mask.type(): out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] @@ -146,3 +121,20 @@ def _tensor_reduction(self, reduction_type, tensor_list): if reduction_type == "concat": return torch.cat(tensor_list, dim=1) raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) + + def _get_decision(self, mutable): + """ + By default, this method checks whether `mutable.key` is already in the decision cache, + and returns the result without double-check. + + Parameters + ---------- + mutable: Mutable + + Returns + ------- + any + """ + if mutable.key not in self._cache: + raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) + return self._cache[mutable.key] diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py index 6e385b1170..da31b3cc69 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -11,14 +11,14 @@ class PdartsMutator(DartsMutator): - def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None): self.pdarts_epoch_index = pdarts_epoch_index self.pdarts_num_to_drop = pdarts_num_to_drop self.switches = switches - super(PdartsMutator, self).__init__(model) + super(PdartsMutator, self).__init__() - def before_build(self, model): + def before_build(self): self.choices = nn.ParameterDict() if self.switches is None: self.switches = {} diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index 4d9c231143..d4ef2bbb8e 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -35,8 +35,7 @@ def train(self): layers = self.layers+self.pdarts_num_layers[epoch] model, loss, model_optim, _ = self.model_creator( layers, n_nodes) - mutator = PdartsMutator( - model, epoch, self.pdarts_num_to_drop, switches) + mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) # pylint: disable=too-many-function-args self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, mutator=mutator, **self.darts_parameters) diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index ab18e6c6e5..a4954a0747 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -1,24 +1,39 @@ +import json +import logging from abc import abstractmethod import torch from .base_trainer import BaseTrainer +_logger = logging.getLogger(__name__) + + +class TorchTensorEncoder(json.JSONEncoder): + def default(self, o): # pylint: disable=method-hidden + if isinstance(o, torch.Tensor): + olist = o.tolist() + if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)): + _logger.warning("Every element in %s is either 0 or 1. 
" + "You might consider convert it into bool.", olist) + return olist + return super().default(o) + class Trainer(BaseTrainer): - def __init__(self, model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator, callbacks): + def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs, + dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device self.model = model + self.mutator = mutator self.loss = loss + self.metrics = metrics self.optimizer = optimizer - self.mutator = mutator self.model.to(self.device) - self.loss.to(self.device) self.mutator.to(self.device) + self.loss.to(self.device) self.num_epochs = num_epochs self.dataset_train = dataset_train @@ -38,7 +53,7 @@ def train_one_epoch(self, epoch): def validate_one_epoch(self, epoch): pass - def _train(self, validate): + def train(self, validate=True): for epoch in range(self.num_epochs): for callback in self.callbacks: callback.on_epoch_begin(epoch) @@ -55,11 +70,13 @@ def _train(self, validate): for callback in self.callbacks: callback.on_epoch_end(epoch) - def train_and_validate(self): - self._train(True) - - def train(self): - self._train(False) - def validate(self): self.validate_one_epoch(-1) + + def export(self, file): + mutator_export = self.mutator.export() + with open(file, "w") as f: + json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder) + + def checkpoint(self): + raise NotImplementedError("Not implemented yet") diff --git a/src/sdk/pynni/nni/nas/pytorch/utils.py b/src/sdk/pynni/nni/nas/pytorch/utils.py new file mode 100644 index 0000000000..d3a4292155 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/utils.py @@ -0,0 +1,107 @@ +from collections import OrderedDict + +_counter = 0 + + +def global_mutable_counting(): + global _counter + _counter += 1 + return _counter + + +class AverageMeterGroup: + + def __init__(self): + self.meters = OrderedDict() + + def update(self, data): + for k, v in data.items(): + if k not in self.meters: + self.meters[k] = AverageMeter(k, ":4f") + self.meters[k].update(v) + + def __str__(self): + return " ".join(str(v) for _, v in self.meters.items()) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class StructuredMutableTreeNode: + """ + A structured representation of a search space. + A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`. + This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet, + the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a + ``Mutable`` (other than ``MutableScope``). 
+ """ + + def __init__(self, mutable): + self.mutable = mutable + self.children = [] + + def add_child(self, mutable): + self.children.append(StructuredMutableTreeNode(mutable)) + return self.children[-1] + + def type(self): + return type(self.mutable) + + def __iter__(self): + return self.traverse() + + def traverse(self, order="pre", deduplicate=True, memo=None): + """ + Return a generator that generates a list of mutables in this tree. + + Parameters + ---------- + order: str + pre or post. If pre, current mutable is yield before children. Otherwise after. + deduplicate: bool + If true, mutables with the same key will not appear after the first appearance. + memo: dict + An auxiliary variable to make deduplicate happen. + + Returns + ------- + generator of Mutable + """ + if memo is None: + memo = set() + assert order in ["pre", "post"] + if order == "pre": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable + for child in self.children: + for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo): + yield m + if order == "post": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable