Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] support debug mode for easy debuggability #3476

Merged
merged 8 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/en_US/NAS/retiarii/Tutorial.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Neural Architecture Search with Retiarii (Alpha)
================================================

*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this figure is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.*
*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this feature is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.*

`Retiarii <https://www.usenix.org/system/files/osdi20-zhang_quanlu.pdf>`__ is a new framework to support neural architecture search and hyper-parameter tuning. It allows users to express various search space with high flexibility, to reuse many SOTA search algorithms, and to leverage system level optimizations to speed up the search process. This framework provides the following new user experiences.

Expand Down Expand Up @@ -197,6 +197,8 @@ After all the above are prepared, it is time to start an experiment to do the mo

The complete code of a simple MNIST example can be found :githublink:`here <test/retiarii_test/mnist/test.py>`.

**Local Debug Mode**: When running an experiment, it is easy to get some trivial errors in trial code, such as shape mismatch, undefined variable. To quickly fix these kinds of errors, we provide local debug mode which locally applies mutators once and runs only that generated model. To use local debug mode, users can simply invoke the API `debug_mutated_model(base_model, trainer, applied_mutators)`.

Visualize the Experiment
------------------------

Expand Down
82 changes: 52 additions & 30 deletions nni/retiarii/experiment/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,47 @@ def _validation_rules(self):
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}

def preprocess_model(base_model, trainer, applied_mutators):
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.evaluator = trainer

# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators

def debug_mutated_model(base_model, trainer, applied_mutators):
"""
Locally run only one trial without launching an experiment for debug purpose, then exit.
For example, it can be used to quickly check shape mismatch.

Specifically, it applies mutators (default to choose the first candidate for the choices)
to generate a new model, then run this model locally.

Parameters
----------
base_model : nni.retiarii.nn.pytorch.nn.Module
the base model
trainer : nni.retiarii.evaluator
the training class of the generated models
applied_mutators : list
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators)
from ..strategy import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')


class RetiariiExperiment(Experiment):
def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
Expand All @@ -116,31 +157,13 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None

self._strategy_thread: Optional[Thread] = None

def _start_strategy(self):
try:
script_module = torch.jit.script(self.base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, self.base_model)
base_model_ir.evaluator = self.trainer

# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
self.applied_mutators = mutators
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators)

_logger.info('Starting strategy...')
# This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self._strategy_thread = Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators))
self._strategy_thread.start()
_logger.info('Strategy started!')
Thread(target=self._strategy_monitor).start()
_logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
self._dispatcher.mark_experiment_as_ending()

def start(self, port: int = 8080, debug: bool = False) -> None:
"""
Expand Down Expand Up @@ -185,15 +208,15 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg)

Thread(target=self._check_exp_status).start()
self._start_strategy()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')

def _create_dispatcher(self):
return self._dispatcher

def _strategy_monitor(self):
self._strategy_thread.join()
self._dispatcher.mark_experiment_as_ending()

def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
Expand All @@ -204,15 +227,14 @@ def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool =
else:
assert config is not None, 'You are using classic search mode, config cannot be None!'
self.config = config
self._run(port, debug)
self.start(port, debug)

def _run(self, port: int = 8080, debug: bool = False) -> bool:
def _check_exp_status(self) -> bool:
"""
Run the experiment.
This function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
"""
self.start(port, debug)
try:
while True:
time.sleep(10)
Expand Down
1 change: 1 addition & 0 deletions nni/retiarii/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy
from .local_debug_strategy import _LocalDebugStrategy
44 changes: 44 additions & 0 deletions nni/retiarii/strategy/local_debug_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import random
import string

from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData
from .base import BaseStrategy

_logger = logging.getLogger(__name__)

class ChooseFirstSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return candidates[0]

class _LocalDebugStrategy(BaseStrategy):
"""
This class is supposed to be used internally, for debugging trial mutation
"""

def run_one_model(self, model):
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
os.remove(file_name)

def run(self, base_model, applied_mutators):
_logger.info('local debug strategy has been started.')
model = base_model
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
choose_first_sampler = ChooseFirstSampler()
for mutator in applied_mutators:
mutator.bind_sampler(choose_first_sampler)
model = mutator.apply(model)
# directly run models
self.run_one_model(model)
6 changes: 5 additions & 1 deletion test/retiarii_test/mnist/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
Expand Down Expand Up @@ -42,6 +42,10 @@ def forward(self, x):
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2)

# uncomment the following two lines to debug a generated model
#debug_mutated_model(base_model, trainer, [])
#exit(0)

simple_strategy = strategy.Random()

exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
Expand Down