diff --git a/docs/en_US/NAS/retiarii/Tutorial.rst b/docs/en_US/NAS/retiarii/Tutorial.rst index fb98cce232..ca5821f136 100644 --- a/docs/en_US/NAS/retiarii/Tutorial.rst +++ b/docs/en_US/NAS/retiarii/Tutorial.rst @@ -1,7 +1,7 @@ Neural Architecture Search with Retiarii (Alpha) ================================================ -*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this figure is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.* +*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this feature is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.* `Retiarii `__ is a new framework to support neural architecture search and hyper-parameter tuning. It allows users to express various search space with high flexibility, to reuse many SOTA search algorithms, and to leverage system level optimizations to speed up the search process. This framework provides the following new user experiences. @@ -197,6 +197,8 @@ After all the above are prepared, it is time to start an experiment to do the mo The complete code of a simple MNIST example can be found :githublink:`here `. +**Local Debug Mode**: When running an experiment, it is easy to get some trivial errors in trial code, such as shape mismatch, undefined variable. To quickly fix these kinds of errors, we provide local debug mode which locally applies mutators once and runs only that generated model. To use local debug mode, users can simply invoke the API `debug_mutated_model(base_model, trainer, applied_mutators)`. + Visualize the Experiment ------------------------ diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index af2eb989f5..09e2636e2a 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -98,6 +98,47 @@ def _validation_rules(self): 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') } +def preprocess_model(base_model, trainer, applied_mutators): + try: + script_module = torch.jit.script(base_model) + except Exception as e: + _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') + raise e + base_model_ir = convert_to_graph(script_module, base_model) + base_model_ir.evaluator = trainer + + # handle inline mutations + mutators = process_inline_mutation(base_model_ir) + if mutators is not None and applied_mutators: + raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' + 'do not use mutators when you use LayerChoice/InputChoice') + if mutators is not None: + applied_mutators = mutators + return base_model_ir, applied_mutators + +def debug_mutated_model(base_model, trainer, applied_mutators): + """ + Locally run only one trial without launching an experiment for debug purpose, then exit. + For example, it can be used to quickly check shape mismatch. + + Specifically, it applies mutators (default to choose the first candidate for the choices) + to generate a new model, then run this model locally. + + Parameters + ---------- + base_model : nni.retiarii.nn.pytorch.nn.Module + the base model + trainer : nni.retiarii.evaluator + the training class of the generated models + applied_mutators : list + a list of mutators that will be applied on the base model for generating a new model + """ + base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators) + from ..strategy import _LocalDebugStrategy + strategy = _LocalDebugStrategy() + strategy.run(base_model_ir, applied_mutators) + _logger.info('local debug completed!') + class RetiariiExperiment(Experiment): def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer], @@ -116,31 +157,13 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT self._proc: Optional[Popen] = None self._pipe: Optional[Pipe] = None - self._strategy_thread: Optional[Thread] = None - def _start_strategy(self): - try: - script_module = torch.jit.script(self.base_model) - except Exception as e: - _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') - raise e - base_model_ir = convert_to_graph(script_module, self.base_model) - base_model_ir.evaluator = self.trainer - - # handle inline mutations - mutators = process_inline_mutation(base_model_ir) - if mutators is not None and self.applied_mutators: - raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' - 'do not use mutators when you use LayerChoice/InputChoice') - if mutators is not None: - self.applied_mutators = mutators + base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators) - _logger.info('Starting strategy...') - # This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later. - self._strategy_thread = Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)) - self._strategy_thread.start() - _logger.info('Strategy started!') - Thread(target=self._strategy_monitor).start() + _logger.info('Start strategy...') + self.strategy.run(base_model_ir, self.applied_mutators) + _logger.info('Strategy exit') + self._dispatcher.mark_experiment_as_ending() def start(self, port: int = 8080, debug: bool = False) -> None: """ @@ -185,15 +208,15 @@ def start(self, port: int = 8080, debug: bool = False) -> None: msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL _logger.info(msg) + Thread(target=self._check_exp_status).start() self._start_strategy() + # TODO: the experiment should be completed, when strategy exits and there is no running job + # _logger.info('Waiting for submitted trial jobs to finish...') + _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...') def _create_dispatcher(self): return self._dispatcher - def _strategy_monitor(self): - self._strategy_thread.join() - self._dispatcher.mark_experiment_as_ending() - def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """ Run the experiment. @@ -204,15 +227,14 @@ def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = else: assert config is not None, 'You are using classic search mode, config cannot be None!' self.config = config - self._run(port, debug) + self.start(port, debug) - def _run(self, port: int = 8080, debug: bool = False) -> bool: + def _check_exp_status(self) -> bool: """ Run the experiment. This function will block until experiment finish or error. Return `True` when experiment done; or return `False` when experiment failed. """ - self.start(port, debug) try: while True: time.sleep(10) diff --git a/nni/retiarii/strategy/__init__.py b/nni/retiarii/strategy/__init__.py index f6a981f03f..e3cd6c5591 100644 --- a/nni/retiarii/strategy/__init__.py +++ b/nni/retiarii/strategy/__init__.py @@ -5,3 +5,4 @@ from .bruteforce import Random, GridSearch from .evolution import RegularizedEvolution from .tpe_strategy import TPEStrategy +from .local_debug_strategy import _LocalDebugStrategy diff --git a/nni/retiarii/strategy/local_debug_strategy.py b/nni/retiarii/strategy/local_debug_strategy.py new file mode 100644 index 0000000000..743d6b2fc6 --- /dev/null +++ b/nni/retiarii/strategy/local_debug_strategy.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import os +import random +import string + +from .. import Sampler, codegen, utils +from ..execution.base import BaseGraphData +from .base import BaseStrategy + +_logger = logging.getLogger(__name__) + +class ChooseFirstSampler(Sampler): + def choice(self, candidates, mutator, model, index): + return candidates[0] + +class _LocalDebugStrategy(BaseStrategy): + """ + This class is supposed to be used internally, for debugging trial mutation + """ + + def run_one_model(self, model): + graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) + file_name = f'_generated_model/{random_str}.py' + os.makedirs(os.path.dirname(file_name), exist_ok=True) + with open(file_name, 'w') as f: + f.write(graph_data.model_script) + model_cls = utils.import_(f'_generated_model.{random_str}._model') + graph_data.evaluator._execute(model_cls) + os.remove(file_name) + + def run(self, base_model, applied_mutators): + _logger.info('local debug strategy has been started.') + model = base_model + _logger.debug('New model created. Applied mutators: %s', str(applied_mutators)) + choose_first_sampler = ChooseFirstSampler() + for mutator in applied_mutators: + mutator.bind_sampler(choose_first_sampler) + model = mutator.apply(model) + # directly run models + self.run_one_model(model) diff --git a/test/retiarii_test/mnist/test.py b/test/retiarii_test/mnist/test.py index 2f128496c4..3aeaa15115 100644 --- a/test/retiarii_test/mnist/test.py +++ b/test/retiarii_test/mnist/test.py @@ -5,7 +5,7 @@ import nni.retiarii.evaluator.pytorch.lightning as pl import torch.nn.functional as F from nni.retiarii import serialize -from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment +from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST @@ -42,6 +42,10 @@ def forward(self, x): val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), max_epochs=2) + # uncomment the following two lines to debug a generated model + #debug_mutated_model(base_model, trainer, []) + #exit(0) + simple_strategy = strategy.Random() exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)