From efda084feef186510650ce02467e994deabb57d7 Mon Sep 17 00:00:00 2001 From: quzha Date: Thu, 25 Mar 2021 19:14:33 +0800 Subject: [PATCH 1/8] support debug mode for retiarii --- nni/retiarii/experiment/pytorch.py | 16 ++++++- nni/retiarii/strategy/__init__.py | 1 + nni/retiarii/strategy/local_debug_strategy.py | 42 +++++++++++++++++++ test/retiarii_test/mnist/test.py | 3 +- 4 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 nni/retiarii/strategy/local_debug_strategy.py diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index af2eb989f5..b04b8e530f 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -118,7 +118,7 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT self._strategy_thread: Optional[Thread] = None - def _start_strategy(self): + def _preprocess_model(self): try: script_module = torch.jit.script(self.base_model) except Exception as e: @@ -134,6 +134,10 @@ def _start_strategy(self): 'do not use mutators when you use LayerChoice/InputChoice') if mutators is not None: self.applied_mutators = mutators + return base_model_ir, mutators + + def _start_strategy(self): + base_model_ir, _ = self._preprocess_model() _logger.info('Starting strategy...') # This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later. @@ -194,6 +198,16 @@ def _strategy_monitor(self): self._strategy_thread.join() self._dispatcher.mark_experiment_as_ending() + def local_debug_run(self): + """ + Locally run one trial for debug, then exit + """ + base_model_ir, applied_mutators = self._preprocess_model() + from ..strategy import LocalDebugStrategy + strategy = LocalDebugStrategy() + strategy.run(base_model_ir, applied_mutators) + _logger.info('local debug completed!') + def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """ Run the experiment. diff --git a/nni/retiarii/strategy/__init__.py b/nni/retiarii/strategy/__init__.py index f6a981f03f..1be72af4af 100644 --- a/nni/retiarii/strategy/__init__.py +++ b/nni/retiarii/strategy/__init__.py @@ -5,3 +5,4 @@ from .bruteforce import Random, GridSearch from .evolution import RegularizedEvolution from .tpe_strategy import TPEStrategy +from .local_debug_strategy import LocalDebugStrategy diff --git a/nni/retiarii/strategy/local_debug_strategy.py b/nni/retiarii/strategy/local_debug_strategy.py new file mode 100644 index 0000000000..9c9dbba85c --- /dev/null +++ b/nni/retiarii/strategy/local_debug_strategy.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import os +import random +import string +import time + +from .. import Sampler, codegen, utils +from ..execution.base import BaseGraphData +from .base import BaseStrategy + +_logger = logging.getLogger(__name__) + +class ChooseFirstSampler(Sampler): + def choice(self, candidates, mutator, model, index): + return candidates[0] + +class LocalDebugStrategy(BaseStrategy): + + def run_one_model(self, model): + graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) + random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) + file_name = f'_generated_model/{random_str}.py' + os.makedirs(os.path.dirname(file_name), exist_ok=True) + with open(file_name, 'w') as f: + f.write(graph_data.model_script) + model_cls = utils.import_(f'_generated_model.{random_str}._model') + graph_data.evaluator._execute(model_cls) + os.remove(file_name) + + def run(self, base_model, applied_mutators): + _logger.info('local debug strategy has been started.') + model = base_model + _logger.debug('New model created. Applied mutators: %s', str(applied_mutators)) + choose_first_sampler = ChooseFirstSampler() + for mutator in applied_mutators: + mutator.bind_sampler(choose_first_sampler) + model = mutator.apply(model) + # directly run models + self.run_one_model(model) diff --git a/test/retiarii_test/mnist/test.py b/test/retiarii_test/mnist/test.py index 2f128496c4..26c8d96eda 100644 --- a/test/retiarii_test/mnist/test.py +++ b/test/retiarii_test/mnist/test.py @@ -52,4 +52,5 @@ def forward(self, x): exp_config.max_trial_number = 10 exp_config.training_service.use_active_gpu = False - exp.run(exp_config, 8081 + random.randint(0, 100)) + #exp.run(exp_config, 8081 + random.randint(0, 100)) + exp.local_debug_run() From 58e7b890533607b5363c863f3c8a3c2f959463e0 Mon Sep 17 00:00:00 2001 From: quzha Date: Mon, 29 Mar 2021 10:00:50 +0800 Subject: [PATCH 2/8] make strategy the main thread --- examples/trials/mnist-pytorch/config.yml | 10 +++--- nni/retiarii/experiment/pytorch.py | 32 +++++++++---------- nni/retiarii/strategy/local_debug_strategy.py | 1 - test/retiarii_test/mnist/test.py | 4 +-- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/examples/trials/mnist-pytorch/config.yml b/examples/trials/mnist-pytorch/config.yml index 00a95216aa..0bd483c545 100644 --- a/examples/trials/mnist-pytorch/config.yml +++ b/examples/trials/mnist-pytorch/config.yml @@ -1,8 +1,8 @@ authorName: default experimentName: example_mnist_pytorch -trialConcurrency: 1 -maxExecDuration: 1h -maxTrialNum: 10 +trialConcurrency: 3 +maxExecDuration: 10000h +maxTrialNum: 10000 #choice: local, remote, pai trainingServicePlatform: local searchSpacePath: search_space.json @@ -11,11 +11,11 @@ useAnnotation: false tuner: #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner #SMAC (SMAC should be installed through nnictl) - builtinTunerName: TPE + builtinTunerName: SMAC classArgs: #choice: maximize, minimize optimize_mode: maximize trial: command: python3 mnist.py codeDir: . - gpuNum: 0 + gpuNum: 1 diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index b04b8e530f..4a76fd71c5 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -116,8 +116,6 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT self._proc: Optional[Popen] = None self._pipe: Optional[Pipe] = None - self._strategy_thread: Optional[Thread] = None - def _preprocess_model(self): try: script_module = torch.jit.script(self.base_model) @@ -139,12 +137,10 @@ def _preprocess_model(self): def _start_strategy(self): base_model_ir, _ = self._preprocess_model() - _logger.info('Starting strategy...') - # This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later. - self._strategy_thread = Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)) - self._strategy_thread.start() - _logger.info('Strategy started!') - Thread(target=self._strategy_monitor).start() + _logger.info('Start strategy...') + self.strategy.run(base_model_ir, self.applied_mutators) + _logger.info('Strategy exit') + self._dispatcher.mark_experiment_as_ending() def start(self, port: int = 8080, debug: bool = False) -> None: """ @@ -189,24 +185,27 @@ def start(self, port: int = 8080, debug: bool = False) -> None: msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL _logger.info(msg) + Thread(target=self._check_exp_status).start() self._start_strategy() + # TODO: the experiment should be completed, when strategy exits and there is no running job + #_logger.info('Strategy exits. Waiting for submitted trial jobs to finish...') + _logger.info('Strategy exits. Waiting for experiment to become DONE...') def _create_dispatcher(self): return self._dispatcher - def _strategy_monitor(self): - self._strategy_thread.join() - self._dispatcher.mark_experiment_as_ending() - - def local_debug_run(self): + def local_debug_run(self, gpu_indices=None): """ - Locally run one trial for debug, then exit + Locally run only one trial without launching an experiment for debug purpose, then exit. + For example, it can be used to quickly check shape mismatch. """ base_model_ir, applied_mutators = self._preprocess_model() from ..strategy import LocalDebugStrategy strategy = LocalDebugStrategy() strategy.run(base_model_ir, applied_mutators) _logger.info('local debug completed!') + self._dispatcher.stopping = True + self._dispatcher = None def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """ @@ -218,15 +217,14 @@ def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = else: assert config is not None, 'You are using classic search mode, config cannot be None!' self.config = config - self._run(port, debug) + self.start(port, debug) - def _run(self, port: int = 8080, debug: bool = False) -> bool: + def _check_exp_status(self) -> bool: """ Run the experiment. This function will block until experiment finish or error. Return `True` when experiment done; or return `False` when experiment failed. """ - self.start(port, debug) try: while True: time.sleep(10) diff --git a/nni/retiarii/strategy/local_debug_strategy.py b/nni/retiarii/strategy/local_debug_strategy.py index 9c9dbba85c..a6504b5851 100644 --- a/nni/retiarii/strategy/local_debug_strategy.py +++ b/nni/retiarii/strategy/local_debug_strategy.py @@ -5,7 +5,6 @@ import os import random import string -import time from .. import Sampler, codegen, utils from ..execution.base import BaseGraphData diff --git a/test/retiarii_test/mnist/test.py b/test/retiarii_test/mnist/test.py index 26c8d96eda..fb541cbcff 100644 --- a/test/retiarii_test/mnist/test.py +++ b/test/retiarii_test/mnist/test.py @@ -52,5 +52,5 @@ def forward(self, x): exp_config.max_trial_number = 10 exp_config.training_service.use_active_gpu = False - #exp.run(exp_config, 8081 + random.randint(0, 100)) - exp.local_debug_run() + exp.run(exp_config, 8081 + random.randint(0, 100)) + #exp.local_debug_run() From b86548518608ca4fc7b22a886d8da264dc32d264 Mon Sep 17 00:00:00 2001 From: quzha Date: Mon, 29 Mar 2021 10:02:53 +0800 Subject: [PATCH 3/8] revert change --- examples/trials/mnist-pytorch/config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/trials/mnist-pytorch/config.yml b/examples/trials/mnist-pytorch/config.yml index 0bd483c545..00a95216aa 100644 --- a/examples/trials/mnist-pytorch/config.yml +++ b/examples/trials/mnist-pytorch/config.yml @@ -1,8 +1,8 @@ authorName: default experimentName: example_mnist_pytorch -trialConcurrency: 3 -maxExecDuration: 10000h -maxTrialNum: 10000 +trialConcurrency: 1 +maxExecDuration: 1h +maxTrialNum: 10 #choice: local, remote, pai trainingServicePlatform: local searchSpacePath: search_space.json @@ -11,11 +11,11 @@ useAnnotation: false tuner: #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner #SMAC (SMAC should be installed through nnictl) - builtinTunerName: SMAC + builtinTunerName: TPE classArgs: #choice: maximize, minimize optimize_mode: maximize trial: command: python3 mnist.py codeDir: . - gpuNum: 1 + gpuNum: 0 From 1700964e6c0a0fca6dafee2dfbcac7324eb3057a Mon Sep 17 00:00:00 2001 From: quzha Date: Mon, 29 Mar 2021 10:09:28 +0800 Subject: [PATCH 4/8] minor --- nni/retiarii/experiment/pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index 4a76fd71c5..b3fdb5f5e2 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -188,13 +188,13 @@ def start(self, port: int = 8080, debug: bool = False) -> None: Thread(target=self._check_exp_status).start() self._start_strategy() # TODO: the experiment should be completed, when strategy exits and there is no running job - #_logger.info('Strategy exits. Waiting for submitted trial jobs to finish...') - _logger.info('Strategy exits. Waiting for experiment to become DONE...') + # _logger.info('Waiting for submitted trial jobs to finish...') + _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...') def _create_dispatcher(self): return self._dispatcher - def local_debug_run(self, gpu_indices=None): + def local_debug_run(self): """ Locally run only one trial without launching an experiment for debug purpose, then exit. For example, it can be used to quickly check shape mismatch. From 952a6a60d2e04729da9963c13f3cc4366d6b126b Mon Sep 17 00:00:00 2001 From: quzha Date: Mon, 29 Mar 2021 10:34:01 +0800 Subject: [PATCH 5/8] update doc accordingly --- docs/en_US/NAS/retiarii/Tutorial.rst | 2 ++ nni/retiarii/experiment/pytorch.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/docs/en_US/NAS/retiarii/Tutorial.rst b/docs/en_US/NAS/retiarii/Tutorial.rst index fb98cce232..d0b78b3c09 100644 --- a/docs/en_US/NAS/retiarii/Tutorial.rst +++ b/docs/en_US/NAS/retiarii/Tutorial.rst @@ -197,6 +197,8 @@ After all the above are prepared, it is time to start an experiment to do the mo The complete code of a simple MNIST example can be found :githublink:`here `. +**Local Debug Mode**: When running an experiment, it is easy to get some trivial errors in trial code, such as shape mismatch, undefined variable. To quickly fix these kinds of errors, we provide local debug mode which locally applies mutators once and runs only that generated model. To use local debug mode, users can simply replace ``exp.run(exp_config, 8081)`` in above code snippet with ``exp.local_debug_run()``. + Visualize the Experiment ------------------------ diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index b3fdb5f5e2..d3bb31a356 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -198,6 +198,9 @@ def local_debug_run(self): """ Locally run only one trial without launching an experiment for debug purpose, then exit. For example, it can be used to quickly check shape mismatch. + + Specifically, it applies mutators (default to choose the first candidate for the choices) + to generate a new model, then run this model locally. """ base_model_ir, applied_mutators = self._preprocess_model() from ..strategy import LocalDebugStrategy From ca38580bc1588bfa37faf0fa8d909434ef8adc77 Mon Sep 17 00:00:00 2001 From: quzha Date: Mon, 29 Mar 2021 10:39:41 +0800 Subject: [PATCH 6/8] minor --- docs/en_US/NAS/retiarii/Tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/NAS/retiarii/Tutorial.rst b/docs/en_US/NAS/retiarii/Tutorial.rst index d0b78b3c09..cb527d9b78 100644 --- a/docs/en_US/NAS/retiarii/Tutorial.rst +++ b/docs/en_US/NAS/retiarii/Tutorial.rst @@ -1,7 +1,7 @@ Neural Architecture Search with Retiarii (Alpha) ================================================ -*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this figure is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.* +*This is a pre-release, its interfaces may subject to minor changes. The roadmap of this feature is: experimental in V2.0 -> alpha version in V2.1 -> beta version in V2.2 -> official release in V2.3. Feel free to give us your comments and suggestions.* `Retiarii `__ is a new framework to support neural architecture search and hyper-parameter tuning. It allows users to express various search space with high flexibility, to reuse many SOTA search algorithms, and to leverage system level optimizations to speed up the search process. This framework provides the following new user experiences. From 4e5a30e9692f506b23db5f3dd9ad8252cc126b2d Mon Sep 17 00:00:00 2001 From: quzha Date: Tue, 6 Apr 2021 09:54:26 +0800 Subject: [PATCH 7/8] refactor local debug mode interface --- nni/retiarii/experiment/pytorch.py | 77 ++++++++++--------- nni/retiarii/strategy/__init__.py | 2 +- nni/retiarii/strategy/local_debug_strategy.py | 5 +- test/retiarii_test/mnist/test.py | 7 +- 4 files changed, 52 insertions(+), 39 deletions(-) diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index d3bb31a356..09e2636e2a 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -98,6 +98,47 @@ def _validation_rules(self): 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') } +def preprocess_model(base_model, trainer, applied_mutators): + try: + script_module = torch.jit.script(base_model) + except Exception as e: + _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') + raise e + base_model_ir = convert_to_graph(script_module, base_model) + base_model_ir.evaluator = trainer + + # handle inline mutations + mutators = process_inline_mutation(base_model_ir) + if mutators is not None and applied_mutators: + raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' + 'do not use mutators when you use LayerChoice/InputChoice') + if mutators is not None: + applied_mutators = mutators + return base_model_ir, applied_mutators + +def debug_mutated_model(base_model, trainer, applied_mutators): + """ + Locally run only one trial without launching an experiment for debug purpose, then exit. + For example, it can be used to quickly check shape mismatch. + + Specifically, it applies mutators (default to choose the first candidate for the choices) + to generate a new model, then run this model locally. + + Parameters + ---------- + base_model : nni.retiarii.nn.pytorch.nn.Module + the base model + trainer : nni.retiarii.evaluator + the training class of the generated models + applied_mutators : list + a list of mutators that will be applied on the base model for generating a new model + """ + base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators) + from ..strategy import _LocalDebugStrategy + strategy = _LocalDebugStrategy() + strategy.run(base_model_ir, applied_mutators) + _logger.info('local debug completed!') + class RetiariiExperiment(Experiment): def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer], @@ -116,26 +157,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT self._proc: Optional[Popen] = None self._pipe: Optional[Pipe] = None - def _preprocess_model(self): - try: - script_module = torch.jit.script(self.base_model) - except Exception as e: - _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') - raise e - base_model_ir = convert_to_graph(script_module, self.base_model) - base_model_ir.evaluator = self.trainer - - # handle inline mutations - mutators = process_inline_mutation(base_model_ir) - if mutators is not None and self.applied_mutators: - raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' - 'do not use mutators when you use LayerChoice/InputChoice') - if mutators is not None: - self.applied_mutators = mutators - return base_model_ir, mutators - def _start_strategy(self): - base_model_ir, _ = self._preprocess_model() + base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators) _logger.info('Start strategy...') self.strategy.run(base_model_ir, self.applied_mutators) @@ -194,22 +217,6 @@ def start(self, port: int = 8080, debug: bool = False) -> None: def _create_dispatcher(self): return self._dispatcher - def local_debug_run(self): - """ - Locally run only one trial without launching an experiment for debug purpose, then exit. - For example, it can be used to quickly check shape mismatch. - - Specifically, it applies mutators (default to choose the first candidate for the choices) - to generate a new model, then run this model locally. - """ - base_model_ir, applied_mutators = self._preprocess_model() - from ..strategy import LocalDebugStrategy - strategy = LocalDebugStrategy() - strategy.run(base_model_ir, applied_mutators) - _logger.info('local debug completed!') - self._dispatcher.stopping = True - self._dispatcher = None - def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: """ Run the experiment. diff --git a/nni/retiarii/strategy/__init__.py b/nni/retiarii/strategy/__init__.py index 1be72af4af..e3cd6c5591 100644 --- a/nni/retiarii/strategy/__init__.py +++ b/nni/retiarii/strategy/__init__.py @@ -5,4 +5,4 @@ from .bruteforce import Random, GridSearch from .evolution import RegularizedEvolution from .tpe_strategy import TPEStrategy -from .local_debug_strategy import LocalDebugStrategy +from .local_debug_strategy import _LocalDebugStrategy diff --git a/nni/retiarii/strategy/local_debug_strategy.py b/nni/retiarii/strategy/local_debug_strategy.py index a6504b5851..743d6b2fc6 100644 --- a/nni/retiarii/strategy/local_debug_strategy.py +++ b/nni/retiarii/strategy/local_debug_strategy.py @@ -16,7 +16,10 @@ class ChooseFirstSampler(Sampler): def choice(self, candidates, mutator, model, index): return candidates[0] -class LocalDebugStrategy(BaseStrategy): +class _LocalDebugStrategy(BaseStrategy): + """ + This class is supposed to be used internally, for debugging trial mutation + """ def run_one_model(self, model): graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) diff --git a/test/retiarii_test/mnist/test.py b/test/retiarii_test/mnist/test.py index fb541cbcff..3aeaa15115 100644 --- a/test/retiarii_test/mnist/test.py +++ b/test/retiarii_test/mnist/test.py @@ -5,7 +5,7 @@ import nni.retiarii.evaluator.pytorch.lightning as pl import torch.nn.functional as F from nni.retiarii import serialize -from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment +from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST @@ -42,6 +42,10 @@ def forward(self, x): val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), max_epochs=2) + # uncomment the following two lines to debug a generated model + #debug_mutated_model(base_model, trainer, []) + #exit(0) + simple_strategy = strategy.Random() exp = RetiariiExperiment(base_model, trainer, [], simple_strategy) @@ -53,4 +57,3 @@ def forward(self, x): exp_config.training_service.use_active_gpu = False exp.run(exp_config, 8081 + random.randint(0, 100)) - #exp.local_debug_run() From f1b14e4251fdc3ee9ac927383756a3218e726be6 Mon Sep 17 00:00:00 2001 From: quzha Date: Tue, 6 Apr 2021 09:59:52 +0800 Subject: [PATCH 8/8] update doc --- docs/en_US/NAS/retiarii/Tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/NAS/retiarii/Tutorial.rst b/docs/en_US/NAS/retiarii/Tutorial.rst index cb527d9b78..ca5821f136 100644 --- a/docs/en_US/NAS/retiarii/Tutorial.rst +++ b/docs/en_US/NAS/retiarii/Tutorial.rst @@ -197,7 +197,7 @@ After all the above are prepared, it is time to start an experiment to do the mo The complete code of a simple MNIST example can be found :githublink:`here `. -**Local Debug Mode**: When running an experiment, it is easy to get some trivial errors in trial code, such as shape mismatch, undefined variable. To quickly fix these kinds of errors, we provide local debug mode which locally applies mutators once and runs only that generated model. To use local debug mode, users can simply replace ``exp.run(exp_config, 8081)`` in above code snippet with ``exp.local_debug_run()``. +**Local Debug Mode**: When running an experiment, it is easy to get some trivial errors in trial code, such as shape mismatch, undefined variable. To quickly fix these kinds of errors, we provide local debug mode which locally applies mutators once and runs only that generated model. To use local debug mode, users can simply invoke the API `debug_mutated_model(base_model, trainer, applied_mutators)`. Visualize the Experiment ------------------------