Single Path One Shot (#1849)
Yuge Zhang authored Dec 24, 2019
1 parent 4f3ee9c commit 6f256c7
Showing 16 changed files with 1,215 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/nas/.gitignore
@@ -1,3 +1,4 @@
data
checkpoints
runs
nni_auto_gen_search_space.json
88 changes: 88 additions & 0 deletions examples/nas/spos/README.md
@@ -0,0 +1,88 @@
# Single Path One-Shot Neural Architecture Search with Uniform Sampling

Single Path One-Shot by Megvii Research. [Paper link](https://arxiv.org/abs/1904.00420). [Official repo](https://github.com/megvii-model/SinglePathOneShot).

Block search only. Channel search is not supported yet.

Only the GPU version is provided here.

## Preparation

### Requirements

* PyTorch >= 1.2
* NVIDIA DALI >= 0.16, as we use DALI to accelerate ImageNet data loading. [Installation guide](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html)

### Data

Download the FLOPs lookup table from [here](https://1drv.ms/u/s!Am_mmG2-KsrnajesvSdfsq_cN48?e=aHVppN).
Put `op_flops_dict.pkl` and `checkpoint-150000.pth.tar` (if you don't want to retrain the supernet) under the `data` directory.
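
To sanity-check the download, the lookup table can be opened with `pickle`. A minimal sketch (the exact key/value schema is an assumption, as it is not documented here):

```
import pickle

# assumed format: a dict mapping an op/shape description to its FLOPs
with open("data/op_flops_dict.pkl", "rb") as f:
    op_flops_dict = pickle.load(f)
print(len(op_flops_dict), "entries loaded")
```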

Prepare ImageNet in the standard format (follow the script [here](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)). Linking it to `data/imagenet` is the most convenient setup.

After preparation, the directory structure is expected to look like this:

```
spos
├── architecture_final.json
├── blocks.py
├── config_search.yml
├── data
│   ├── imagenet
│   │   ├── train
│   │   └── val
│   └── op_flops_dict.pkl
├── dataloader.py
├── network.py
├── README.md
├── scratch.py
├── supernet.py
├── tester.py
├── tuner.py
└── utils.py
```

## Step 1. Train Supernet

```
python supernet.py
```

This exports checkpoints to the `checkpoints` directory for use in the next step.

NOTE: The data loading used in the official repo is [slightly different from usual](https://github.com/megvii-model/SinglePathOneShot/issues/5): they intentionally use BGR tensors and keep the values between 0 and 255 to align with their own DL framework. The `--spos-preprocessing` option simulates the original behavior and enables you to use the pretrained checkpoints.
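
For reference, a rough sketch of what this preprocessing amounts to, compared with standard ImageNet pipelines (an illustration, not the repo's actual code path):

```
import torch

def spos_style(img):  # img: uint8 RGB tensor of shape (3, H, W)
    # SPOS keeps pixel values in [0, 255] and uses BGR channel order;
    # standard pipelines scale to [0, 1] and normalize with ImageNet stats.
    return img[[2, 1, 0]].float()
```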

## Step 2. Evolution Search

Single Path One-Shot leverages an evolutionary algorithm to search for the best architecture. The tester, which is responsible for testing each sampled architecture, recalculates all the batch normalization statistics on a subset of training images and then evaluates the architecture on the full validation set; a sketch of this recalibration step is shown below.
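
A minimal sketch of the batch norm recalibration idea (the function name and batch count are illustrative, not the tester's actual code):

```
import torch

def recalibrate_bn(model, train_loader, num_batches=200):
    # reset running statistics, then re-estimate them with
    # forward-only passes over a subset of training images
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()
    model.train()
    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            if i >= num_batches:
                break
            model(images)
```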

To prepare a search space for the NNI framework, first run:

```
nnictl ss_gen -t "python tester.py"
```

This will generate a file called `nni_auto_gen_search_space.json`, which is a serialized representation of your search space.
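
You can load the file to verify what was generated (the exact schema depends on your NNI version, so this only inspects it):

```
import json

with open("nni_auto_gen_search_space.json") as f:
    search_space = json.load(f)
print(list(search_space))  # expected to contain one entry per LayerChoice
```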

Then launch the search with the evolution tuner:

```
nnictl create --config config_search.yml
```

The final architecture exported from every epoch of evolution can be found in `checkpoints` under the working directory of your tuner, which, by default, is `$HOME/nni/experiments/your_experiment_id/log`.

## Step 3. Train from Scratch

```
python scratch.py
```

By default, it will use `architecture_final.json`. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with the `--fixed-arc` option.
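
`architecture_final.json` (included in this commit and shown below) stores a one-hot list over the four candidate blocks for each layer. A small sketch for turning it into chosen indices:

```
import json

with open("architecture_final.json") as f:
    arch = json.load(f)
# each value is a one-hot list over 4 candidate blocks
chosen = {name: onehot.index(True) for name, onehot in arch.items()}
```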

## Current Reproduction Results

Reproduction is still in progress. Due to the gap between the official release and the original paper, we compare our current results with both the official repo (our own run) and the paper.

* The evolution phase is almost aligned with the official repo. Our evolution algorithm shows a converging trend and reaches ~65% accuracy at the end of the search. Nevertheless, this result is not on par with the paper. For details, please refer to [this issue](https://github.com/megvii-model/SinglePathOneShot/issues/6).
* The retraining phase is not yet aligned. Our retraining code, which uses the architecture released by the authors, reaches 72.14% accuracy, still short of the 73.61% from the official release and the 74.3% reported in the original paper.
22 changes: 22 additions & 0 deletions examples/nas/spos/architecture_final.json
@@ -0,0 +1,22 @@
{
"LayerChoice1": [false, false, true, false],
"LayerChoice2": [false, true, false, false],
"LayerChoice3": [true, false, false, false],
"LayerChoice4": [false, true, false, false],
"LayerChoice5": [false, false, true, false],
"LayerChoice6": [true, false, false, false],
"LayerChoice7": [false, false, true, false],
"LayerChoice8": [true, false, false, false],
"LayerChoice9": [false, false, true, false],
"LayerChoice10": [true, false, false, false],
"LayerChoice11": [false, false, true, false],
"LayerChoice12": [false, false, false, true],
"LayerChoice13": [true, false, false, false],
"LayerChoice14": [true, false, false, false],
"LayerChoice15": [true, false, false, false],
"LayerChoice16": [true, false, false, false],
"LayerChoice17": [false, false, false, true],
"LayerChoice18": [false, false, true, false],
"LayerChoice19": [false, false, false, true],
"LayerChoice20": [false, false, false, true]
}
89 changes: 89 additions & 0 deletions examples/nas/spos/blocks.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class ShuffleNetBlock(nn.Module):
"""
When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
"""

def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
super().__init__()
assert stride in [1, 2]
assert ksize in [3, 5, 7]
self.channels = inp // 2 if stride == 1 else inp
self.inp = inp
self.oup = oup
self.mid_channels = mid_channels
self.ksize = ksize
self.stride = stride
self.pad = ksize // 2
self.oup_main = oup - self.channels
assert self.oup_main > 0

self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

if stride == 2:
self.branch_proj = nn.Sequential(
# dw
nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
groups=self.channels, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
# pw-linear
nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
nn.ReLU(inplace=True)
)

def forward(self, x):
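        # stride 2: both branches consume the full input;
        # stride 1: a channel shuffle splits the input in half between branches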
if self.stride == 2:
x_proj, x = self.branch_proj(x), x
else:
x_proj, x = self._channel_shuffle(x)
return torch.cat((x_proj, self.branch_main(x)), 1)

def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc = c = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = self.mid_channels
if token == "d":
# depth-wise conv
assert pc == c, "Depth-wise conv must not change channels."
result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
groups=c, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
first_depth = False
elif token == "p":
# point-wise conv
result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
result.append(nn.ReLU(inplace=True))
first_point = False
else:
                raise ValueError("Conv sequence must consist of 'd' and 'p' tokens.")
pc = c
return result

def _channel_shuffle(self, x):
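        # group adjacent channel pairs and transpose, so that each returned
        # half takes one channel from every pair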
bs, num_channels, height, width = x.data.size()
assert (num_channels % 4 == 0)
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]


class ShuffleXceptionBlock(ShuffleNetBlock):

def __init__(self, inp, oup, mid_channels, stride):
super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp")
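
For reference, a hypothetical usage of `ShuffleNetBlock` (shapes chosen purely for illustration):

```
import torch
from blocks import ShuffleNetBlock

# a stride-2 block mapping 64 -> 128 channels with a 5x5 depth-wise kernel
block = ShuffleNetBlock(inp=64, oup=128, mid_channels=64, ksize=5, stride=2)
out = block(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 128, 16, 16])
```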
16 changes: 16 additions & 0 deletions examples/nas/spos/config_search.yml
@@ -0,0 +1,16 @@
authorName: unknown
experimentName: SPOS Search
trialConcurrency: 4
maxExecDuration: 7d
maxTrialNum: 99999
trainingServicePlatform: local
searchSpacePath: nni_auto_gen_search_space.json
useAnnotation: false
tuner:
codeDir: .
classFileName: tuner.py
className: EvolutionWithFlops
trial:
command: python tester.py --imagenet-dir /path/to/your/imagenet --spos-prep
codeDir: .
gpuNum: 1
106 changes: 106 additions & 0 deletions examples/nas/spos/dataloader.py
@@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torch.utils.data
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator


class HybridTrainPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1,
spos_pre=False):
super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
color_space_type = types.BGR if spos_pre else types.RGB
self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
self.res = ops.RandomResizedCrop(device="gpu", size=crop,
interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
self.twist = ops.ColorTwist(device="gpu")
self.jitter_rng = ops.Uniform(range=[0.6, 1.4])
self.cmnp = ops.CropMirrorNormalize(device="gpu",
output_dtype=types.FLOAT,
output_layout=types.NCHW,
image_type=color_space_type,
mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])
self.coin = ops.CoinFlip(probability=0.5)

def define_graph(self):
rng = self.coin()
self.jpegs, self.labels = self.input(name="Reader")
images = self.decode(self.jpegs)
images = self.res(images)
images = self.twist(images, saturation=self.jitter_rng(),
contrast=self.jitter_rng(), brightness=self.jitter_rng())
output = self.cmnp(images, mirror=rng)
return [output, self.labels]


class HybridValPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1,
spos_pre=False, shuffle=False):
super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
color_space_type = types.BGR if spos_pre else types.RGB
self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
random_shuffle=shuffle)
self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
self.res = ops.Resize(device="gpu", resize_shorter=size,
interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
self.cmnp = ops.CropMirrorNormalize(device="gpu",
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=color_space_type,
mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])

def define_graph(self):
self.jpegs, self.labels = self.input(name="Reader")
images = self.decode(self.jpegs)
images = self.res(images)
output = self.cmnp(images)
return [output, self.labels]


class ClassificationWrapper:
def __init__(self, loader, size):
self.loader = loader
self.size = size

def __iter__(self):
return self

def __next__(self):
data = next(self.loader)
return data[0]["data"], data[0]["label"].view(-1).long().cuda(non_blocking=True)

def __len__(self):
return self.size


def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256,
spos_preprocessing=False, seed=12, shuffle=False, device_id=None):
world_size, local_rank = 1, 0
if device_id is None:
device_id = torch.cuda.device_count() - 1 # use last gpu
if split == "train":
pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
data_dir=os.path.join(image_dir, "train"), seed=seed,
crop=crop, world_size=world_size, local_rank=local_rank,
spos_pre=spos_preprocessing)
elif split == "val":
pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
data_dir=os.path.join(image_dir, "val"), seed=seed,
crop=crop, size=val_size, world_size=world_size, local_rank=local_rank,
spos_pre=spos_preprocessing, shuffle=shuffle)
else:
        raise ValueError("split must be 'train' or 'val'")
pipeline.build()
num_samples = pipeline.epoch_size("Reader")
return ClassificationWrapper(
DALIClassificationIterator(pipeline, size=num_samples, fill_last_batch=split == "train",
auto_reset=True), (num_samples + batch_size - 1) // batch_size)
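
A hypothetical usage of `get_imagenet_iter_dali` (the path and batch size are illustrative):

```
# assumes an ImageNet layout like data/imagenet/{train,val}
loader = get_imagenet_iter_dali("train", "data/imagenet", batch_size=256,
                                num_threads=4, spos_preprocessing=True)
for images, labels in loader:
    pass  # images: GPU float tensors; labels: flattened CUDA LongTensors
```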