From 6f256c781a5c2cc841060435d3b732daa6adb108 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 24 Dec 2019 11:52:56 +0800 Subject: [PATCH] Single Path One Shot (#1849) --- examples/nas/.gitignore | 1 + examples/nas/spos/README.md | 88 +++++++ examples/nas/spos/architecture_final.json | 22 ++ examples/nas/spos/blocks.py | 89 +++++++ examples/nas/spos/config_search.yml | 16 ++ examples/nas/spos/dataloader.py | 106 +++++++++ examples/nas/spos/network.py | 156 ++++++++++++ examples/nas/spos/scratch.py | 128 ++++++++++ examples/nas/spos/supernet.py | 74 ++++++ examples/nas/spos/tester.py | 115 +++++++++ examples/nas/spos/tuner.py | 25 ++ examples/nas/spos/utils.py | 41 ++++ .../pynni/nni/nas/pytorch/spos/__init__.py | 6 + .../pynni/nni/nas/pytorch/spos/evolution.py | 222 ++++++++++++++++++ src/sdk/pynni/nni/nas/pytorch/spos/mutator.py | 63 +++++ src/sdk/pynni/nni/nas/pytorch/spos/trainer.py | 63 +++++ 16 files changed, 1215 insertions(+) create mode 100644 examples/nas/spos/README.md create mode 100644 examples/nas/spos/architecture_final.json create mode 100644 examples/nas/spos/blocks.py create mode 100644 examples/nas/spos/config_search.yml create mode 100644 examples/nas/spos/dataloader.py create mode 100644 examples/nas/spos/network.py create mode 100644 examples/nas/spos/scratch.py create mode 100644 examples/nas/spos/supernet.py create mode 100644 examples/nas/spos/tester.py create mode 100644 examples/nas/spos/tuner.py create mode 100644 examples/nas/spos/utils.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/spos/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/spos/evolution.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/spos/mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/spos/trainer.py diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 8eeb0c2a3f..e26f9a17a1 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1,3 +1,4 @@ data checkpoints runs +nni_auto_gen_search_space.json diff --git a/examples/nas/spos/README.md b/examples/nas/spos/README.md new file mode 100644 index 0000000000..ed239f30a1 --- /dev/null +++ b/examples/nas/spos/README.md @@ -0,0 +1,88 @@ +# Single Path One-Shot Neural Architecture Search with Uniform Sampling + +Single Path One-Shot by Megvii Research. [Paper link](https://arxiv.org/abs/1904.00420). [Official repo](https://github.com/megvii-model/SinglePathOneShot). + +Block search only. Channel search is not supported yet. + +Only GPU version is provided here. + +## Preparation + +### Requirements + +* PyTorch >= 1.2 +* NVIDIA DALI >= 0.16 as we use DALI to accelerate the data loading of ImageNet. [Installation guide](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html) + +### Data + +Need to download the flops lookup table from [here](https://1drv.ms/u/s!Am_mmG2-KsrnajesvSdfsq_cN48?e=aHVppN). +Put `op_flops_dict.pkl` and `checkpoint-150000.pth.tar` (if you don't want to retrain the supernet) under `data` directory. + +Prepare ImageNet in the standard format (follow the script [here](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)). Link it to `data/imagenet` will be more convenient. + +After preparation, it's expected to have the following code structure: + +``` +spos +├── architecture_final.json +├── blocks.py +├── config_search.yml +├── data +│   ├── imagenet +│   │   ├── train +│   │   └── val +│   └── op_flops_dict.pkl +├── dataloader.py +├── network.py +├── readme.md +├── scratch.py +├── supernet.py +├── tester.py +├── tuner.py +└── utils.py +``` + +## Step 1. Train Supernet + +``` +python supernet.py +``` + +Will export the checkpoint to checkpoints directory, for the next step. + +NOTE: The data loading used in the official repo is [slightly different from usual](https://github.com/megvii-model/SinglePathOneShot/issues/5), as they use BGR tensor and keep the values between 0 and 255 intentionally to align with their own DL framework. The option `--spos-preprocessing` will simulate the behavior used originally and enable you to use the checkpoints pretrained. + +## Step 2. Evolution Search + +Single Path One-Shot leverages evolution algorithm to search for the best architecture. The tester, which is responsible for testing the sampled architecture, recalculates all the batch norm for a subset of training images, and evaluates the architecture on the full validation set. + +To have a search space ready for NNI framework, first run + +``` +nnictl ss_gen -t "python tester.py" +``` + +This will generate a file called `nni_auto_gen_search_space.json`, which is a serialized representation of your search space. + +Then search with evolution tuner. + +``` +nnictl create --config config_search.yml +``` + +The final architecture exported from every epoch of evolution can be found in `checkpoints` under the working directory of your tuner, which, by default, is `$HOME/nni/experiments/your_experiment_id/log`. + +## Step 3. Train from Scratch + +``` +python scratch.py +``` + +By default, it will use `architecture_final.json`. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with `--fixed-arc` option. + +## Current Reproduction Results + +Reproduction is still undergoing. Due to the gap between official release and original paper, we compare our current results with official repo (our run) and paper. + +* Evolution phase is almost aligned with official repo. Our evolution algorithm shows a converging trend and reaches ~65% accuracy at the end of search. Nevertheless, this result is not on par with paper. For details, please refer to [this issue](https://github.com/megvii-model/SinglePathOneShot/issues/6). +* Retrain phase is not aligned. Our retraining code, which uses the architecture released by the authors, reaches 72.14% accuracy, still having a gap towards 73.61% by official release and 74.3% reported in original paper. diff --git a/examples/nas/spos/architecture_final.json b/examples/nas/spos/architecture_final.json new file mode 100644 index 0000000000..512a73b9d6 --- /dev/null +++ b/examples/nas/spos/architecture_final.json @@ -0,0 +1,22 @@ +{ + "LayerChoice1": [false, false, true, false], + "LayerChoice2": [false, true, false, false], + "LayerChoice3": [true, false, false, false], + "LayerChoice4": [false, true, false, false], + "LayerChoice5": [false, false, true, false], + "LayerChoice6": [true, false, false, false], + "LayerChoice7": [false, false, true, false], + "LayerChoice8": [true, false, false, false], + "LayerChoice9": [false, false, true, false], + "LayerChoice10": [true, false, false, false], + "LayerChoice11": [false, false, true, false], + "LayerChoice12": [false, false, false, true], + "LayerChoice13": [true, false, false, false], + "LayerChoice14": [true, false, false, false], + "LayerChoice15": [true, false, false, false], + "LayerChoice16": [true, false, false, false], + "LayerChoice17": [false, false, false, true], + "LayerChoice18": [false, false, true, false], + "LayerChoice19": [false, false, false, true], + "LayerChoice20": [false, false, false, true] +} diff --git a/examples/nas/spos/blocks.py b/examples/nas/spos/blocks.py new file mode 100644 index 0000000000..5908ecf077 --- /dev/null +++ b/examples/nas/spos/blocks.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import torch.nn as nn + + +class ShuffleNetBlock(nn.Module): + """ + When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels. + """ + + def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"): + super().__init__() + assert stride in [1, 2] + assert ksize in [3, 5, 7] + self.channels = inp // 2 if stride == 1 else inp + self.inp = inp + self.oup = oup + self.mid_channels = mid_channels + self.ksize = ksize + self.stride = stride + self.pad = ksize // 2 + self.oup_main = oup - self.channels + assert self.oup_main > 0 + + self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence)) + + if stride == 2: + self.branch_proj = nn.Sequential( + # dw + nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad, + groups=self.channels, bias=False), + nn.BatchNorm2d(self.channels, affine=False), + # pw-linear + nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(self.channels, affine=False), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + if self.stride == 2: + x_proj, x = self.branch_proj(x), x + else: + x_proj, x = self._channel_shuffle(x) + return torch.cat((x_proj, self.branch_main(x)), 1) + + def _decode_point_depth_conv(self, sequence): + result = [] + first_depth = first_point = True + pc = c = self.channels + for i, token in enumerate(sequence): + # compute output channels of this conv + if i + 1 == len(sequence): + assert token == "p", "Last conv must be point-wise conv." + c = self.oup_main + elif token == "p" and first_point: + c = self.mid_channels + if token == "d": + # depth-wise conv + assert pc == c, "Depth-wise conv must not change channels." + result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad, + groups=c, bias=False)) + result.append(nn.BatchNorm2d(c, affine=False)) + first_depth = False + elif token == "p": + # point-wise conv + result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False)) + result.append(nn.BatchNorm2d(c, affine=False)) + result.append(nn.ReLU(inplace=True)) + first_point = False + else: + raise ValueError("Conv sequence must be d and p.") + pc = c + return result + + def _channel_shuffle(self, x): + bs, num_channels, height, width = x.data.size() + assert (num_channels % 4 == 0) + x = x.reshape(bs * num_channels // 2, 2, height * width) + x = x.permute(1, 0, 2) + x = x.reshape(2, -1, num_channels // 2, height, width) + return x[0], x[1] + + +class ShuffleXceptionBlock(ShuffleNetBlock): + + def __init__(self, inp, oup, mid_channels, stride): + super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp") diff --git a/examples/nas/spos/config_search.yml b/examples/nas/spos/config_search.yml new file mode 100644 index 0000000000..fe27faefc8 --- /dev/null +++ b/examples/nas/spos/config_search.yml @@ -0,0 +1,16 @@ +authorName: unknown +experimentName: SPOS Search +trialConcurrency: 4 +maxExecDuration: 7d +maxTrialNum: 99999 +trainingServicePlatform: local +searchSpacePath: nni_auto_gen_search_space.json +useAnnotation: false +tuner: + codeDir: . + classFileName: tuner.py + className: EvolutionWithFlops +trial: + command: python tester.py --imagenet-dir /path/to/your/imagenet --spos-prep + codeDir: . + gpuNum: 1 diff --git a/examples/nas/spos/dataloader.py b/examples/nas/spos/dataloader.py new file mode 100644 index 0000000000..198d637ed1 --- /dev/null +++ b/examples/nas/spos/dataloader.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +import nvidia.dali.ops as ops +import nvidia.dali.types as types +import torch.utils.data +from nvidia.dali.pipeline import Pipeline +from nvidia.dali.plugin.pytorch import DALIClassificationIterator + + +class HybridTrainPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1, + spos_pre=False): + super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id) + color_space_type = types.BGR if spos_pre else types.RGB + self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True) + self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type) + self.res = ops.RandomResizedCrop(device="gpu", size=crop, + interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR) + self.twist = ops.ColorTwist(device="gpu") + self.jitter_rng = ops.Uniform(range=[0.6, 1.4]) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + image_type=color_space_type, + mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) + self.coin = ops.CoinFlip(probability=0.5) + + def define_graph(self): + rng = self.coin() + self.jpegs, self.labels = self.input(name="Reader") + images = self.decode(self.jpegs) + images = self.res(images) + images = self.twist(images, saturation=self.jitter_rng(), + contrast=self.jitter_rng(), brightness=self.jitter_rng()) + output = self.cmnp(images, mirror=rng) + return [output, self.labels] + + +class HybridValPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1, + spos_pre=False, shuffle=False): + super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id) + color_space_type = types.BGR if spos_pre else types.RGB + self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, + random_shuffle=shuffle) + self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type) + self.res = ops.Resize(device="gpu", resize_shorter=size, + interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(crop, crop), + image_type=color_space_type, + mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) + + def define_graph(self): + self.jpegs, self.labels = self.input(name="Reader") + images = self.decode(self.jpegs) + images = self.res(images) + output = self.cmnp(images) + return [output, self.labels] + + +class ClassificationWrapper: + def __init__(self, loader, size): + self.loader = loader + self.size = size + + def __iter__(self): + return self + + def __next__(self): + data = next(self.loader) + return data[0]["data"], data[0]["label"].view(-1).long().cuda(non_blocking=True) + + def __len__(self): + return self.size + + +def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256, + spos_preprocessing=False, seed=12, shuffle=False, device_id=None): + world_size, local_rank = 1, 0 + if device_id is None: + device_id = torch.cuda.device_count() - 1 # use last gpu + if split == "train": + pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir=os.path.join(image_dir, "train"), seed=seed, + crop=crop, world_size=world_size, local_rank=local_rank, + spos_pre=spos_preprocessing) + elif split == "val": + pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir=os.path.join(image_dir, "val"), seed=seed, + crop=crop, size=val_size, world_size=world_size, local_rank=local_rank, + spos_pre=spos_preprocessing, shuffle=shuffle) + else: + raise AssertionError + pipeline.build() + num_samples = pipeline.epoch_size("Reader") + return ClassificationWrapper( + DALIClassificationIterator(pipeline, size=num_samples, fill_last_batch=split == "train", + auto_reset=True), (num_samples + batch_size - 1) // batch_size) diff --git a/examples/nas/spos/network.py b/examples/nas/spos/network.py new file mode 100644 index 0000000000..ba45095775 --- /dev/null +++ b/examples/nas/spos/network.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import pickle +import re + +import torch +import torch.nn as nn +from nni.nas.pytorch import mutables + +from blocks import ShuffleNetBlock, ShuffleXceptionBlock + + +class ShuffleNetV2OneShot(nn.Module): + block_keys = [ + 'shufflenet_3x3', + 'shufflenet_5x5', + 'shufflenet_7x7', + 'xception_3x3', + ] + + def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000, + op_flops_path="./data/op_flops_dict.pkl"): + super().__init__() + + assert input_size % 32 == 0 + with open(os.path.join(os.path.dirname(__file__), op_flops_path), "rb") as fp: + self._op_flops_dict = pickle.load(fp) + + self.stage_blocks = [4, 4, 8, 4] + self.stage_channels = [64, 160, 320, 640] + self._parsed_flops = dict() + self._input_size = input_size + self._feature_map_size = input_size + self._first_conv_channels = first_conv_channels + self._last_conv_channels = last_conv_channels + self._n_classes = n_classes + + # building first layer + self.first_conv = nn.Sequential( + nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(first_conv_channels, affine=False), + nn.ReLU(inplace=True), + ) + self._feature_map_size //= 2 + + p_channels = first_conv_channels + features = [] + for num_blocks, channels in zip(self.stage_blocks, self.stage_channels): + features.extend(self._make_blocks(num_blocks, p_channels, channels)) + p_channels = channels + self.features = nn.Sequential(*features) + + self.conv_last = nn.Sequential( + nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(last_conv_channels, affine=False), + nn.ReLU(inplace=True), + ) + self.globalpool = nn.AvgPool2d(self._feature_map_size) + self.dropout = nn.Dropout(0.1) + self.classifier = nn.Sequential( + nn.Linear(last_conv_channels, n_classes, bias=False), + ) + + self._initialize_weights() + + def _make_blocks(self, blocks, in_channels, channels): + result = [] + for i in range(blocks): + stride = 2 if i == 0 else 1 + inp = in_channels if i == 0 else channels + oup = channels + + base_mid_channels = channels // 2 + mid_channels = int(base_mid_channels) # prepare for scale + choice_block = mutables.LayerChoice([ + ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride), + ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride), + ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride), + ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride) + ]) + result.append(choice_block) + + # find the corresponding flops + flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride) + self._parsed_flops[choice_block.key] = [ + self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key] for k in self.block_keys + ] + if stride == 2: + self._feature_map_size //= 2 + return result + + def forward(self, x): + bs = x.size(0) + x = self.first_conv(x) + x = self.features(x) + x = self.conv_last(x) + x = self.globalpool(x) + + x = self.dropout(x) + x = x.contiguous().view(bs, -1) + x = self.classifier(x) + return x + + def get_candidate_flops(self, candidate): + conv1_flops = self._op_flops_dict["conv1"][(3, self._first_conv_channels, + self._input_size, self._input_size, 2)] + # Should use `last_conv_channels` here, but megvii insists that it's `n_classes`. Keeping it. + # https://github.com/megvii-model/SinglePathOneShot/blob/36eed6cf083497ffa9cfe7b8da25bb0b6ba5a452/src/Supernet/flops.py#L313 + rest_flops = self._op_flops_dict["rest_operation"][(self.stage_channels[-1], self._n_classes, + self._feature_map_size, self._feature_map_size, 1)] + total_flops = conv1_flops + rest_flops + for k, m in candidate.items(): + parsed_flops_dict = self._parsed_flops[k] + if isinstance(m, dict): # to be compatible with classical nas format + total_flops += parsed_flops_dict[m["_idx"]] + else: + total_flops += parsed_flops_dict[torch.max(m, 0)[1]] + return total_flops + + def _initialize_weights(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'first' in name: + nn.init.normal_(m.weight, 0, 0.01) + else: + nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0001) + nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0001) + nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"): + checkpoint = torch.load(filepath, map_location=torch.device("cpu")) + result = dict() + for k, v in checkpoint["state_dict"].items(): + if k.startswith("module."): + k = k[len("module."):] + k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k) + result[k] = v + return result diff --git a/examples/nas/spos/scratch.py b/examples/nas/spos/scratch.py new file mode 100644 index 0000000000..3a944a7909 --- /dev/null +++ b/examples/nas/spos/scratch.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +import logging +import random + +import numpy as np +import torch +import torch.nn as nn +from dataloader import get_imagenet_iter_dali +from nni.nas.pytorch.fixed import apply_fixed_architecture +from nni.nas.pytorch.utils import AverageMeterGroup +from torch.utils.tensorboard import SummaryWriter + +from network import ShuffleNetV2OneShot +from utils import CrossEntropyLabelSmooth, accuracy + +logger = logging.getLogger("nni.spos.scratch") + + +def train(epoch, model, criterion, optimizer, loader, writer, args): + model.train() + meters = AverageMeterGroup() + cur_lr = optimizer.param_groups[0]["lr"] + + for step, (x, y) in enumerate(loader): + cur_step = len(loader) * epoch + step + optimizer.zero_grad() + logits = model(x) + loss = criterion(logits, y) + loss.backward() + optimizer.step() + + metrics = accuracy(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + + writer.add_scalar("lr", cur_lr, global_step=cur_step) + writer.add_scalar("loss/train", loss.item(), global_step=cur_step) + writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step) + writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step) + + if step % args.log_frequency == 0 or step + 1 == len(loader): + logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, + args.epochs, step + 1, len(loader), meters) + + logger.info("Epoch %d training summary: %s", epoch + 1, meters) + + +def validate(epoch, model, criterion, loader, writer, args): + model.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + for step, (x, y) in enumerate(loader): + logits = model(x) + loss = criterion(logits, y) + metrics = accuracy(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + + if step % args.log_frequency == 0 or step + 1 == len(loader): + logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s", epoch + 1, + args.epochs, step + 1, len(loader), meters) + + writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch) + writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch) + writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch) + + logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("SPOS Training From Scratch") + parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet") + parser.add_argument("--tb-dir", type=str, default="runs") + parser.add_argument("--architecture", type=str, default="architecture_final.json") + parser.add_argument("--workers", type=int, default=12) + parser.add_argument("--batch-size", type=int, default=1024) + parser.add_argument("--epochs", type=int, default=240) + parser.add_argument("--learning-rate", type=float, default=0.5) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--weight-decay", type=float, default=4E-5) + parser.add_argument("--label-smooth", type=float, default=0.1) + parser.add_argument("--log-frequency", type=int, default=10) + parser.add_argument("--lr-decay", type=str, default="linear") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--spos-preprocessing", default=False, action="store_true") + parser.add_argument("--label-smoothing", type=float, default=0.1) + + args = parser.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + torch.backends.cudnn.deterministic = True + + model = ShuffleNetV2OneShot() + model.cuda() + apply_fixed_architecture(model, args.architecture) + if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu + model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1))) + criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, + momentum=args.momentum, weight_decay=args.weight_decay) + if args.lr_decay == "linear": + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, + lambda step: (1.0 - step / args.epochs) + if step <= args.epochs else 0, + last_epoch=-1) + elif args.lr_decay == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 1E-3) + else: + raise ValueError("'%s' not supported." % args.lr_decay) + writer = SummaryWriter(log_dir=args.tb_dir) + + train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing) + val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing) + + for epoch in range(args.epochs): + train(epoch, model, criterion, optimizer, train_loader, writer, args) + validate(epoch, model, criterion, val_loader, writer, args) + scheduler.step() + + writer.close() diff --git a/examples/nas/spos/supernet.py b/examples/nas/spos/supernet.py new file mode 100644 index 0000000000..3ab717868c --- /dev/null +++ b/examples/nas/spos/supernet.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +import logging +import random + +import numpy as np +import torch +import torch.nn as nn +from nni.nas.pytorch.callbacks import LRSchedulerCallback +from nni.nas.pytorch.callbacks import ModelCheckpoint +from nni.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer + +from dataloader import get_imagenet_iter_dali +from network import ShuffleNetV2OneShot, load_and_parse_state_dict +from utils import CrossEntropyLabelSmooth, accuracy + +logger = logging.getLogger("nni.spos.supernet") + +if __name__ == "__main__": + parser = argparse.ArgumentParser("SPOS Supernet Training") + parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet") + parser.add_argument("--load-checkpoint", action="store_true", default=False) + parser.add_argument("--spos-preprocessing", action="store_true", default=False, + help="When true, image values will range from 0 to 255 and use BGR " + "(as in original repo).") + parser.add_argument("--workers", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=768) + parser.add_argument("--epochs", type=int, default=120) + parser.add_argument("--learning-rate", type=float, default=0.5) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--weight-decay", type=float, default=4E-5) + parser.add_argument("--label-smooth", type=float, default=0.1) + parser.add_argument("--log-frequency", type=int, default=10) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--label-smoothing", type=float, default=0.1) + + args = parser.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + torch.backends.cudnn.deterministic = True + + model = ShuffleNetV2OneShot() + if args.load_checkpoint: + if not args.spos_preprocessing: + logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.") + model.load_state_dict(load_and_parse_state_dict()) + model.cuda() + if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu + model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1))) + mutator = SPOSSupernetTrainingMutator(model, flops_func=model.module.get_candidate_flops, + flops_lb=290E6, flops_ub=360E6) + criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, + momentum=args.momentum, weight_decay=args.weight_decay) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, + lambda step: (1.0 - step / args.epochs) + if step <= args.epochs else 0, + last_epoch=-1) + train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing) + valid_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing) + trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer, + args.epochs, train_loader, valid_loader, + mutator=mutator, batch_size=args.batch_size, + log_frequency=args.log_frequency, workers=args.workers, + callbacks=[LRSchedulerCallback(scheduler), + ModelCheckpoint("./checkpoints")]) + trainer.train() diff --git a/examples/nas/spos/tester.py b/examples/nas/spos/tester.py new file mode 100644 index 0000000000..b31b8f2fab --- /dev/null +++ b/examples/nas/spos/tester.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +import logging +import random +import time +from itertools import cycle + +import nni +import numpy as np +import torch +import torch.nn as nn +from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture +from nni.nas.pytorch.utils import AverageMeterGroup + +from dataloader import get_imagenet_iter_dali +from network import ShuffleNetV2OneShot, load_and_parse_state_dict +from utils import CrossEntropyLabelSmooth, accuracy + +logger = logging.getLogger("nni.spos.tester") + + +def retrain_bn(model, criterion, max_iters, log_freq, loader): + with torch.no_grad(): + logger.info("Clear BN statistics...") + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.running_mean = torch.zeros_like(m.running_mean) + m.running_var = torch.ones_like(m.running_var) + + logger.info("Train BN with training set (BN sanitize)...") + model.train() + meters = AverageMeterGroup() + for step in range(max_iters): + inputs, targets = next(loader) + logits = model(inputs) + loss = criterion(logits, targets) + metrics = accuracy(logits, targets) + metrics["loss"] = loss.item() + meters.update(metrics) + if step % log_freq == 0 or step + 1 == max_iters: + logger.info("Train Step [%d/%d] %s", step + 1, max_iters, meters) + + +def test_acc(model, criterion, log_freq, loader): + logger.info("Start testing...") + model.eval() + meters = AverageMeterGroup() + start_time = time.time() + with torch.no_grad(): + for step, (inputs, targets) in enumerate(loader): + logits = model(inputs) + loss = criterion(logits, targets) + metrics = accuracy(logits, targets) + metrics["loss"] = loss.item() + meters.update(metrics) + if step % log_freq == 0 or step + 1 == len(loader): + logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f", + step + 1, len(loader), time.time() - start_time, + meters.acc1.avg, meters.acc5.avg, meters.loss.avg) + return meters.acc1.avg + + +def evaluate_acc(model, criterion, args, loader_train, loader_test): + acc_before = test_acc(model, criterion, args.log_frequency, loader_test) + nni.report_intermediate_result(acc_before) + + retrain_bn(model, criterion, args.train_iters, args.log_frequency, loader_train) + acc = test_acc(model, criterion, args.log_frequency, loader_test) + assert isinstance(acc, float) + nni.report_intermediate_result(acc) + nni.report_final_result(acc) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("SPOS Candidate Tester") + parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet") + parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar") + parser.add_argument("--spos-preprocessing", action="store_true", default=False, + help="When true, image values will range from 0 to 255 and use BGR " + "(as in original repo).") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--workers", type=int, default=6) + parser.add_argument("--train-batch-size", type=int, default=128) + parser.add_argument("--train-iters", type=int, default=200) + parser.add_argument("--test-batch-size", type=int, default=512) + parser.add_argument("--log-frequency", type=int, default=10) + + args = parser.parse_args() + + # use a fixed set of image will improve the performance + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + torch.backends.cudnn.deterministic = True + + assert torch.cuda.is_available() + + model = ShuffleNetV2OneShot() + criterion = CrossEntropyLabelSmooth(1000, 0.1) + get_and_apply_next_architecture(model) + model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint)) + model.cuda() + + train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.train_batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing, + seed=args.seed, device_id=0) + val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.test_batch_size, args.workers, + spos_preprocessing=args.spos_preprocessing, shuffle=True, + seed=args.seed, device_id=0) + train_loader = cycle(train_loader) + + evaluate_acc(model, criterion, args, train_loader, val_loader) diff --git a/examples/nas/spos/tuner.py b/examples/nas/spos/tuner.py new file mode 100644 index 0000000000..fb8b9f2aa4 --- /dev/null +++ b/examples/nas/spos/tuner.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from nni.nas.pytorch.spos import SPOSEvolution + +from network import ShuffleNetV2OneShot + + +class EvolutionWithFlops(SPOSEvolution): + """ + This tuner extends the function of evolution tuner, by limiting the flops generated by tuner. + Needs a function to examine the flops. + """ + + def __init__(self, flops_limit=330E6, **kwargs): + super().__init__(**kwargs) + self.model = ShuffleNetV2OneShot() + self.flops_limit = flops_limit + + def _is_legal(self, cand): + if not super()._is_legal(cand): + return False + if self.model.get_candidate_flops(cand) > self.flops_limit: + return False + return True diff --git a/examples/nas/spos/utils.py b/examples/nas/spos/utils.py new file mode 100644 index 0000000000..70ad98b55f --- /dev/null +++ b/examples/nas/spos/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import torch.nn as nn + + +class CrossEntropyLabelSmooth(nn.Module): + + def __init__(self, num_classes, epsilon): + super(CrossEntropyLabelSmooth, self).__init__() + self.num_classes = num_classes + self.epsilon = epsilon + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, targets): + log_probs = self.logsoftmax(inputs) + targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes + loss = (-targets * log_probs).mean(0).sum() + return loss + + +def accuracy(output, target, topk=(1, 5)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res diff --git a/src/sdk/pynni/nni/nas/pytorch/spos/__init__.py b/src/sdk/pynni/nni/nas/pytorch/spos/__init__.py new file mode 100644 index 0000000000..ed432b0845 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/spos/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .evolution import SPOSEvolution +from .mutator import SPOSSupernetTrainingMutator +from .trainer import SPOSSupernetTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/spos/evolution.py b/src/sdk/pynni/nni/nas/pytorch/spos/evolution.py new file mode 100644 index 0000000000..3541c81fd7 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/spos/evolution.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import logging +import os +import re +from collections import deque + +import numpy as np +from nni.tuner import Tuner +from nni.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE + + +_logger = logging.getLogger(__name__) + + +class SPOSEvolution(Tuner): + + def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1, + num_crossover=25, num_mutation=25): + """ + Initialize SPOS Evolution Tuner. + + Parameters + ---------- + max_epochs : int + Maximum number of epochs to run. + num_select : int + Number of survival candidates of each epoch. + num_population : int + Number of candidates at the start of each epoch. If candidates generated by + crossover and mutation are not enough, the rest will be filled with random + candidates. + m_prob : float + The probability of mutation. + num_crossover : int + Number of candidates generated by crossover in each epoch. + num_mutation : int + Number of candidates generated by mutation in each epoch. + """ + assert num_population >= num_select + self.max_epochs = max_epochs + self.num_select = num_select + self.num_population = num_population + self.m_prob = m_prob + self.num_crossover = num_crossover + self.num_mutation = num_mutation + self.epoch = 0 + self.candidates = [] + self.search_space = None + self.random_state = np.random.RandomState(0) + + # async status + self._to_evaluate_queue = deque() + self._sending_parameter_queue = deque() + self._pending_result_ids = set() + self._reward_dict = dict() + self._id2candidate = dict() + self._st_callback = None + + def update_search_space(self, search_space): + """ + Handle the initialization/update event of search space. + """ + self._search_space = search_space + self._next_round() + + def _next_round(self): + _logger.info("Epoch %d, generating...", self.epoch) + if self.epoch == 0: + self._get_random_population() + self.export_results(self.candidates) + else: + best_candidates = self._select_top_candidates() + self.export_results(best_candidates) + if self.epoch >= self.max_epochs: + return + self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates) + self._get_random_population() + self.epoch += 1 + + def _random_candidate(self): + chosen_arch = dict() + for key, val in self._search_space.items(): + if val["_type"] == LAYER_CHOICE: + choices = val["_value"] + index = self.random_state.randint(len(choices)) + chosen_arch[key] = {"_value": choices[index], "_idx": index} + elif val["_type"] == INPUT_CHOICE: + raise NotImplementedError("Input choice is not implemented yet.") + return chosen_arch + + def _add_to_evaluate_queue(self, cand): + _logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand)) + self._reward_dict[self._hashcode(cand)] = 0. + self._to_evaluate_queue.append(cand) + + def _get_random_population(self): + while len(self.candidates) < self.num_population: + cand = self._random_candidate() + if self._is_legal(cand): + _logger.info("Random candidate generated.") + self._add_to_evaluate_queue(cand) + self.candidates.append(cand) + + def _get_crossover(self, best): + result = [] + for _ in range(10 * self.num_crossover): + cand_p1 = best[self.random_state.randint(len(best))] + cand_p2 = best[self.random_state.randint(len(best))] + assert cand_p1.keys() == cand_p2.keys() + cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k] + for k in cand_p1.keys()} + if self._is_legal(cand): + result.append(cand) + self._add_to_evaluate_queue(cand) + if len(result) >= self.num_crossover: + break + _logger.info("Found %d architectures with crossover.", len(result)) + return result + + def _get_mutation(self, best): + result = [] + for _ in range(10 * self.num_mutation): + cand = best[self.random_state.randint(len(best))].copy() + mutation_sample = np.random.random_sample(len(cand)) + for s, k in zip(mutation_sample, cand): + if s < self.m_prob: + choices = self._search_space[k]["_value"] + index = self.random_state.randint(len(choices)) + cand[k] = {"_value": choices[index], "_idx": index} + if self._is_legal(cand): + result.append(cand) + self._add_to_evaluate_queue(cand) + if len(result) >= self.num_mutation: + break + _logger.info("Found %d architectures with mutation.", len(result)) + return result + + def _get_architecture_repr(self, cand): + return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1", + self._hashcode(cand)) + + def _is_legal(self, cand): + if self._hashcode(cand) in self._reward_dict: + return False + return True + + def _select_top_candidates(self): + reward_query = lambda cand: self._reward_dict[self._hashcode(cand)] + _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates))) + result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select] + _logger.info("Best candidate rewards: %s", list(map(reward_query, result))) + return result + + @staticmethod + def _hashcode(d): + return json.dumps(d, sort_keys=True) + + def _bind_and_send_parameters(self): + """ + There are two types of resources: parameter ids and candidates. This function is called at + necessary times to bind these resources to send new trials with st_callback. + """ + result = [] + while self._sending_parameter_queue and self._to_evaluate_queue: + parameter_id = self._sending_parameter_queue.popleft() + parameters = self._to_evaluate_queue.popleft() + self._id2candidate[parameter_id] = parameters + result.append(parameters) + self._pending_result_ids.add(parameter_id) + self._st_callback(parameter_id, parameters) + _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters)) + return result + + def generate_multiple_parameters(self, parameter_id_list, **kwargs): + """ + Callback function necessary to implement a tuner. This will put more parameter ids into the + parameter id queue. + """ + if "st_callback" in kwargs and self._st_callback is None: + self._st_callback = kwargs["st_callback"] + for parameter_id in parameter_id_list: + self._sending_parameter_queue.append(parameter_id) + self._bind_and_send_parameters() + return [] # always not use this. might induce problem of over-sending + + def receive_trial_result(self, parameter_id, parameters, value, **kwargs): + """ + Callback function. Receive a trial result. + """ + _logger.info("Candidate %d, reported reward %f", parameter_id, value) + self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value + + def trial_end(self, parameter_id, success, **kwargs): + """ + Callback function when a trial is ended and resource is released. + """ + self._pending_result_ids.remove(parameter_id) + if not self._pending_result_ids and not self._to_evaluate_queue: + # a new epoch now + self._next_round() + assert self._st_callback is not None + self._bind_and_send_parameters() + + def export_results(self, result): + """ + Export a number of candidates to `checkpoints` dir. + + Parameters + ---------- + result : dict + """ + os.makedirs("checkpoints", exist_ok=True) + for i, cand in enumerate(result): + converted = dict() + for cand_key, cand_val in cand.items(): + onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))] + converted[cand_key] = onehot + with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp: + json.dump(converted, fp) diff --git a/src/sdk/pynni/nni/nas/pytorch/spos/mutator.py b/src/sdk/pynni/nni/nas/pytorch/spos/mutator.py new file mode 100644 index 0000000000..88a01eeeaf --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/spos/mutator.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging + +import numpy as np +from nni.nas.pytorch.random import RandomMutator + +_logger = logging.getLogger(__name__) + + +class SPOSSupernetTrainingMutator(RandomMutator): + def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None, + flops_bin_num=7, flops_sample_timeout=500): + """ + + Parameters + ---------- + model : nn.Module + flops_func : callable + Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func` + is None, functions related to flops will be deactivated. + flops_lb : number + Lower bound of flops. + flops_ub : number + Upper bound of flops. + flops_bin_num : number + Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more + uniform, but the sampling will be slower. + flops_sample_timeout : int + Maximum number of attempts to sample before giving up and use a random candidate. + """ + super().__init__(model) + self._flops_func = flops_func + if self._flops_func is not None: + self._flops_bin_num = flops_bin_num + self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)] + self._flops_sample_timeout = flops_sample_timeout + + def sample_search(self): + """ + Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly + relative to flops. + + Returns + ------- + dict + """ + if self._flops_func is not None: + for times in range(self._flops_sample_timeout): + idx = np.random.randint(self._flops_bin_num) + cand = super().sample_search() + if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]: + _logger.debug("Sampled candidate flops %f in %d times.", cand, times) + return cand + _logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout) + return super().sample_search() + + def sample_final(self): + """ + Implement only to suffice the interface of Mutator. + """ + return self.sample_search() diff --git a/src/sdk/pynni/nni/nas/pytorch/spos/trainer.py b/src/sdk/pynni/nni/nas/pytorch/spos/trainer.py new file mode 100644 index 0000000000..ab23760bf9 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/spos/trainer.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging + +import torch +from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.utils import AverageMeterGroup + +from .mutator import SPOSSupernetTrainingMutator + +logger = logging.getLogger(__name__) + + +class SPOSSupernetTrainer(Trainer): + """ + This trainer trains a supernet that can be used for evolution search. + """ + + def __init__(self, model, loss, metrics, + optimizer, num_epochs, train_loader, valid_loader, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None): + assert torch.cuda.is_available() + super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model), + loss, metrics, optimizer, num_epochs, None, None, + batch_size, workers, device, log_frequency, callbacks) + + self.train_loader = train_loader + self.valid_loader = valid_loader + + def train_one_epoch(self, epoch): + self.model.train() + meters = AverageMeterGroup() + for step, (x, y) in enumerate(self.train_loader): + self.optimizer.zero_grad() + self.mutator.reset() + logits = self.model(x) + loss = self.loss(logits, y) + loss.backward() + self.optimizer.step() + + metrics = self.metrics(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) + + def validate_one_epoch(self, epoch): + self.model.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + for step, (x, y) in enumerate(self.valid_loader): + self.mutator.reset() + logits = self.model(x) + loss = self.loss(logits, y) + metrics = self.metrics(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.valid_loader), meters)