Single Path One Shot #1849
Merged
+1,215 −0
Changes from all commits (59 commits)
- 0d1ca70 checkpoint
- 5363aa4 checkpoint
- a4fc9cc checkpoint
- 9bc24b0 checkpoint
- ultmaster 7b6ff0d checkpoint
- ultmaster a412ac9 checkpoint
- c720373 checkpoint
- ultmaster 2edf637 checkpoint
- ultmaster 990932b checkpoint
- ultmaster 117399c finish tester
- b4763d9 fix bugs
- ultmaster dfdb9fb checkpoint
- f16256d checkpoint
- 3115531 fix a few issues
- ultmaster 971822c add model checkpoint
- ultmaster 5f51bb5 update
- ultmaster 69261af fix evolution tuner
- ultmaster 79ad154 Merge branch 'nas-spos' of github.com:ultmaster/nni into nas-spos
- ultmaster e3dddf1 update
- ultmaster bc38366 update training from scratch
- ultmaster a445e54 add decision
- 99b3b74 decision class track in
- bdfc2e7 decision class track in
- 1a55007 update from scratch training code
- c7a10d2 update
- ultmaster fc1eb99 update
- ultmaster 34ffa31 fix cur_step error
- ultmaster ef387c9 update
- ultmaster 082abcd update
- ultmaster 28c5b2d update
- ultmaster ff2d2e7 update format
- ultmaster c034b0a update
- ultmaster 8f77321 update
- ultmaster 1f29960 update
- ultmaster e63c3f3 update
- ultmaster 048d604 update
- ultmaster 45d0d7a update
- ultmaster d498a38 update
- ultmaster dbe8680 update
- ultmaster f4e893d update
- ultmaster c24322a update
- ultmaster 138764e updaste
- ultmaster 7fb280a fix pylint
- ultmaster 001c581 update batch size
- ultmaster 5a00af5 update
- ultmaster 4cef622 add evolution doc
- ca47a5b remove decision
- 5c2fbd2 add docstring
- e8d67ca add docstring
- 18bd184 improve docs
- 38bc071 Merge remote-tracking branch 'upstream/master' into nas-spos-2
- 489c6de improve architecture readability
- 752c7d3 add note for provided archit
- a60e8e5 add license
- 63271ab update
- 9871fa6 use enum string
- 86b34e6 add reproduction results
- 7ba24b0 add reproduction results
- ba009a7 add reproduction results
.gitignore
@@ -1,3 +1,4 @@
data
checkpoints
runs
nni_auto_gen_search_space.json
readme.md
@@ -0,0 +1,88 @@
# Single Path One-Shot Neural Architecture Search with Uniform Sampling

Single Path One-Shot by Megvii Research. [Paper link](https://arxiv.org/abs/1904.00420). [Official repo](https://github.com/megvii-model/SinglePathOneShot).

Block search only. Channel search is not supported yet.

Only the GPU version is provided here.

## Preparation

### Requirements

* PyTorch >= 1.2
* NVIDIA DALI >= 0.16, as we use DALI to accelerate ImageNet data loading. [Installation guide](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html)

### Data

Download the flops lookup table from [here](https://1drv.ms/u/s!Am_mmG2-KsrnajesvSdfsq_cN48?e=aHVppN).
Put `op_flops_dict.pkl` and `checkpoint-150000.pth.tar` (if you don't want to retrain the supernet) under the `data` directory.

Prepare ImageNet in the standard format (follow the script [here](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)). Linking it to `data/imagenet` will be more convenient.

After preparation, the directory is expected to have the following structure:

```
spos
├── architecture_final.json
├── blocks.py
├── config_search.yml
├── data
│   ├── imagenet
│   │   ├── train
│   │   └── val
│   └── op_flops_dict.pkl
├── dataloader.py
├── network.py
├── readme.md
├── scratch.py
├── supernet.py
├── tester.py
├── tuner.py
└── utils.py
```

## Step 1. Train Supernet

```
python supernet.py
```

This will export checkpoints to the `checkpoints` directory for use in the next step.

NOTE: The data loading used in the official repo is [slightly different from usual](https://github.com/megvii-model/SinglePathOneShot/issues/5): they use BGR tensors and intentionally keep values between 0 and 255 to align with their own DL framework. The option `--spos-preprocessing` simulates the original behavior and enables you to use the pretrained checkpoints.
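For instance, to train the supernet with the original SPOS preprocessing enabled (only the flag documented above is shown; other options keep their defaults):

```
python supernet.py --spos-preprocessing
```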

## Step 2. Evolution Search

Single Path One-Shot leverages an evolutionary algorithm to search for the best architecture. The tester, which is responsible for testing each sampled architecture, recalculates the batch-normalization statistics on a subset of training images and evaluates the architecture on the full validation set.
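As an illustration of that recalculation step, here is a minimal sketch of batch-norm statistics recalibration in plain PyTorch. It is not the repo's `tester.py`; the function name and arguments are hypothetical.

```
import torch
import torch.nn as nn

def recalibrate_bn(model, train_loader, num_batches):
    # Reset running statistics and switch BN to a cumulative moving average.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None  # None = cumulative moving average in PyTorch
    model.train()  # BN updates running stats only in train mode
    with torch.no_grad():  # only forward passes are needed for statistics
        for i, (images, _) in enumerate(train_loader):
            if i >= num_batches:
                break
            model(images.cuda())
```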
To have a search space ready for the NNI framework, first run

```
nnictl ss_gen -t "python tester.py"
```

This generates a file called `nni_auto_gen_search_space.json`, which is a serialized representation of your search space.

Then search with the evolution tuner:

```
nnictl create --config config_search.yml
```

The final architecture exported from every epoch of evolution can be found in `checkpoints` under the working directory of your tuner, which, by default, is `$HOME/nni/experiments/your_experiment_id/log`.
## Step 3. Train from Scratch

```
python scratch.py
```

By default, this uses `architecture_final.json`, an architecture provided by the official repo (converted into NNI format). You can train any architecture (e.g., the one found in step 2) with the `--fixed-arc` option.
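For example, assuming `--fixed-arc` takes the path of an architecture JSON file (the file name below is purely illustrative):

```
python scratch.py --fixed-arc checkpoints/architecture_from_search.json
```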
## Current Reproduction Results

Reproduction is still in progress. Due to the gap between the official release and the original paper, we compare our current results with both the official repo (our own run of it) and the paper.

* The evolution phase is almost aligned with the official repo. Our evolution algorithm shows a converging trend and reaches ~65% accuracy at the end of the search. Nevertheless, this result is not on par with the paper. For details, please refer to [this issue](https://github.com/megvii-model/SinglePathOneShot/issues/6).
* The retrain phase is not aligned. Our retraining code, which uses the architecture released by the authors, reaches 72.14% accuracy, still short of the 73.61% achieved by the official release and the 74.3% reported in the original paper.
architecture_final.json
@@ -0,0 +1,22 @@
{
  "LayerChoice1": [false, false, true, false],
  "LayerChoice2": [false, true, false, false],
  "LayerChoice3": [true, false, false, false],
  "LayerChoice4": [false, true, false, false],
  "LayerChoice5": [false, false, true, false],
  "LayerChoice6": [true, false, false, false],
  "LayerChoice7": [false, false, true, false],
  "LayerChoice8": [true, false, false, false],
  "LayerChoice9": [false, false, true, false],
  "LayerChoice10": [true, false, false, false],
  "LayerChoice11": [false, false, true, false],
  "LayerChoice12": [false, false, false, true],
  "LayerChoice13": [true, false, false, false],
  "LayerChoice14": [true, false, false, false],
  "LayerChoice15": [true, false, false, false],
  "LayerChoice16": [true, false, false, false],
  "LayerChoice17": [false, false, false, true],
  "LayerChoice18": [false, false, true, false],
  "LayerChoice19": [false, false, false, true],
  "LayerChoice20": [false, false, false, true]
}
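Each entry is a one-hot list over the four candidate blocks for that layer. A hypothetical snippet (not part of this PR) to map each choice to its selected index:

```
import json

# Load the exported architecture and map each one-hot choice to its index.
with open("architecture_final.json") as f:
    arch = json.load(f)

chosen = {layer: onehot.index(True) for layer, onehot in arch.items()}
print(chosen["LayerChoice1"])  # -> 2, i.e. the third candidate block
```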
blocks.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class ShuffleNetBlock(nn.Module):
    """
    When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
    """

    def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
        super().__init__()
        assert stride in [1, 2]
        assert ksize in [3, 5, 7]
        self.channels = inp // 2 if stride == 1 else inp
        self.inp = inp
        self.oup = oup
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.stride = stride
        self.pad = ksize // 2
        self.oup_main = oup - self.channels
        assert self.oup_main > 0

        self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

        if stride == 2:
            self.branch_proj = nn.Sequential(
                # dw
                nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
                          groups=self.channels, bias=False),
                nn.BatchNorm2d(self.channels, affine=False),
                # pw-linear
                nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.channels, affine=False),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        if self.stride == 2:
            x_proj, x = self.branch_proj(x), x
        else:
            x_proj, x = self._channel_shuffle(x)
        return torch.cat((x_proj, self.branch_main(x)), 1)

    def _decode_point_depth_conv(self, sequence):
        result = []
        first_depth = first_point = True
        pc = c = self.channels
        for i, token in enumerate(sequence):
            # compute output channels of this conv
            if i + 1 == len(sequence):
                assert token == "p", "Last conv must be point-wise conv."
                c = self.oup_main
            elif token == "p" and first_point:
                c = self.mid_channels
            if token == "d":
                # depth-wise conv
                assert pc == c, "Depth-wise conv must not change channels."
                result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
                                        groups=c, bias=False))
                result.append(nn.BatchNorm2d(c, affine=False))
                first_depth = False
            elif token == "p":
                # point-wise conv
                result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
                result.append(nn.BatchNorm2d(c, affine=False))
                result.append(nn.ReLU(inplace=True))
                first_point = False
            else:
                raise ValueError("Conv sequence must be d and p.")
            pc = c
        return result

    def _channel_shuffle(self, x):
        # Interleave channels pairwise, then split into two halves (ShuffleNet v2 style).
        bs, num_channels, height, width = x.data.size()
        assert num_channels % 4 == 0
        x = x.reshape(bs * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleXceptionBlock(ShuffleNetBlock):

    def __init__(self, inp, oup, mid_channels, stride):
        super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp")
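A quick shape check, as a hedged usage sketch; the channel and spatial sizes below are illustrative, not taken from the repo's network configuration:

```
import torch

# A stride-2 block consumes `inp` channels and produces `oup` channels at half
# the spatial resolution (projection branch and main branch are concatenated).
block = ShuffleNetBlock(inp=64, oup=128, mid_channels=64, ksize=3, stride=2)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 128, 28, 28])
```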
config_search.yml
@@ -0,0 +1,16 @@
authorName: unknown
experimentName: SPOS Search
trialConcurrency: 4
maxExecDuration: 7d
maxTrialNum: 99999
trainingServicePlatform: local
searchSpacePath: nni_auto_gen_search_space.json
useAnnotation: false
tuner:
  codeDir: .
  classFileName: tuner.py
  className: EvolutionWithFlops
trial:
  command: python tester.py --imagenet-dir /path/to/your/imagenet --spos-prep
  codeDir: .
  gpuNum: 1
dataloader.py
@@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torch.utils.data
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator


class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1,
                 spos_pre=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
        self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
        self.res = ops.RandomResizedCrop(device="gpu", size=crop,
                                         interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        self.twist = ops.ColorTwist(device="gpu")
        self.jitter_rng = ops.Uniform(range=[0.6, 1.4])
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            image_type=color_space_type,
                                            mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        images = self.twist(images, saturation=self.jitter_rng(),
                            contrast=self.jitter_rng(), brightness=self.jitter_rng())
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]


class HybridValPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1,
                 spos_pre=False, shuffle=False):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
                                    random_shuffle=shuffle)
        self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
        self.res = ops.Resize(device="gpu", resize_shorter=size,
                              interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=color_space_type,
                                            mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]


class ClassificationWrapper:
    def __init__(self, loader, size):
        self.loader = loader
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        data = next(self.loader)
        return data[0]["data"], data[0]["label"].view(-1).long().cuda(non_blocking=True)

    def __len__(self):
        return self.size


def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256,
                           spos_preprocessing=False, seed=12, shuffle=False, device_id=None):
    world_size, local_rank = 1, 0
    if device_id is None:
        device_id = torch.cuda.device_count() - 1  # use last gpu
    if split == "train":
        pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                   data_dir=os.path.join(image_dir, "train"), seed=seed,
                                   crop=crop, world_size=world_size, local_rank=local_rank,
                                   spos_pre=spos_preprocessing)
    elif split == "val":
        pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                 data_dir=os.path.join(image_dir, "val"), seed=seed,
                                 crop=crop, size=val_size, world_size=world_size, local_rank=local_rank,
                                 spos_pre=spos_preprocessing, shuffle=shuffle)
    else:
        raise AssertionError("split must be 'train' or 'val'")
    pipeline.build()
    num_samples = pipeline.epoch_size("Reader")
    return ClassificationWrapper(
        DALIClassificationIterator(pipeline, size=num_samples, fill_last_batch=split == "train",
                                   auto_reset=True), (num_samples + batch_size - 1) // batch_size)
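A hedged usage sketch of `get_imagenet_iter_dali` (the path and batch size are illustrative): it returns an iterable yielding `(images, labels)` pairs already resident on the GPU.

```
# Assumes ImageNet is linked at data/imagenet, as suggested in the readme.
loader = get_imagenet_iter_dali("train", "data/imagenet", batch_size=256,
                                num_threads=4, spos_preprocessing=True)
for images, labels in loader:
    # images: NCHW float tensor on GPU; labels: CUDA LongTensor
    pass
```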
Why use this package?

To accelerate data loading for PyTorch.

More detail? By accelerating what? What's the major difference?

ImageNet data loading and augmentation are slow and inefficient. Running with the PyTorch built-in dataloader creates a bottleneck on CPU and memory. Using DALI brings over a 10x speedup on our workstation (4 GTX 1080 GPUs and a 12-core CPU).

The main difference is that data decoding and augmentation are done on the GPU. This also brings some changes to the dataloader interface.

Got it, thanks.

Better to mention this requirement in the doc.

Agreed, we should offer a requirements.txt.

It's already mentioned in the docs. DALI needs a different installation command for CUDA 9 and CUDA 10, so we can't cover both in a requirements.txt: https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html

Can it be executed by a shell script?