This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit 817ec68

Add native support for v2 config (#3466)
liuzhe-lz authored Apr 9, 2021
1 parent 6aaca5f commit 817ec68
Showing 69 changed files with 1,561 additions and 2,146 deletions.
28 changes: 13 additions & 15 deletions nni/__main__.py
@@ -7,7 +7,7 @@
 import json
 import base64
 
-from .runtime.common import enable_multi_thread, enable_multi_phase
+from .runtime.common import enable_multi_thread
 from .runtime.msg_dispatcher import MsgDispatcher
 from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
@@ -29,10 +29,8 @@ def main():
     exp_params = json.loads(exp_params_decode)
     logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
 
-    if exp_params.get('multiThread'):
+    if exp_params.get('deprecated', {}).get('multiThread'):
         enable_multi_thread()
-    if exp_params.get('multiPhase'):
-        enable_multi_phase()
 
     if exp_params.get('advisor') is not None:
         # advisor is enabled and starts to run
@@ -61,10 +59,10 @@
 
 
 def _run_advisor(exp_params):
-    if exp_params.get('advisor').get('builtinAdvisorName'):
+    if exp_params.get('advisor').get('name'):
         dispatcher = create_builtin_class_instance(
-            exp_params.get('advisor').get('builtinAdvisorName'),
-            exp_params.get('advisor').get('classArgs'),
+            exp_params['advisor']['name'],
+            exp_params['advisor'].get('classArgs'),
             'advisors')
     else:
         dispatcher = create_customized_class_instance(exp_params.get('advisor'))
@@ -78,26 +76,26 @@
 
 
 def _create_tuner(exp_params):
-    if exp_params.get('tuner').get('builtinTunerName'):
+    if exp_params['tuner'].get('name'):
         tuner = create_builtin_class_instance(
-            exp_params.get('tuner').get('builtinTunerName'),
-            exp_params.get('tuner').get('classArgs'),
+            exp_params['tuner']['name'],
+            exp_params['tuner'].get('classArgs'),
             'tuners')
     else:
-        tuner = create_customized_class_instance(exp_params.get('tuner'))
+        tuner = create_customized_class_instance(exp_params['tuner'])
     if tuner is None:
         raise AssertionError('Failed to create Tuner instance')
     return tuner
 
 
 def _create_assessor(exp_params):
-    if exp_params.get('assessor').get('builtinAssessorName'):
+    if exp_params['assessor'].get('name'):
         assessor = create_builtin_class_instance(
-            exp_params.get('assessor').get('builtinAssessorName'),
-            exp_params.get('assessor').get('classArgs'),
+            exp_params['assessor']['name'],
+            exp_params['assessor'].get('classArgs'),
             'assessors')
     else:
-        assessor = create_customized_class_instance(exp_params.get('assessor'))
+        assessor = create_customized_class_instance(exp_params['assessor'])
     if assessor is None:
         raise AssertionError('Failed to create Assessor instance')
     return assessor
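In the v2 schema consumed above, builtin algorithms are selected by a single 'name' key (replacing 'builtinTunerName', 'builtinAssessorName' and 'builtinAdvisorName'), and the legacy 'multiThread' flag moves under a 'deprecated' sub-object. A minimal sketch of a payload the updated main() would accept; the values are illustrative assumptions, and only keys touched by this diff are shown:

# Illustrative v2 experiment-parameter payload (values are assumptions).
exp_params = {
    'deprecated': {'multiThread': False},   # legacy flag, now nested
    'tuner': {
        'name': 'TPE',                      # v2: 'name' replaces 'builtinTunerName'
        'classArgs': {'optimize_mode': 'maximize'},
    },
}

# _create_tuner(exp_params) takes the builtin branch because
# exp_params['tuner'].get('name') is truthy; a tuner dict without 'name'
# is passed whole to create_customized_class_instance instead.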
1 change: 1 addition & 0 deletions nni/experiment/config/__init__.py
@@ -9,3 +9,4 @@
 from .kubeflow import *
 from .frameworkcontroller import *
 from .adl import *
+from .shared_storage import *
2 changes: 2 additions & 0 deletions nni/experiment/config/base.py
@@ -101,6 +101,8 @@ def canonical(self: T) -> T:
             elif isinstance(value, ConfigBase):
                 setattr(ret, key, value.canonical())
                 # value will be copied twice, should not be a performance issue anyway
+            elif isinstance(value, Path):
+                setattr(ret, key, str(value))
         return ret
 
     def validate(self) -> None:
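The new branch converts pathlib.Path values to plain strings when a config is canonicalized, keeping the canonical form JSON-serializable. A standalone sketch of just this behavior, using a simplified stand-in rather than NNI's actual ConfigBase:

from dataclasses import dataclass
from pathlib import Path

# Simplified stand-in for ConfigBase.canonical(); only the new
# Path-to-str branch from this diff is modeled.
@dataclass
class DemoConfig:
    code_dir: Path

    def canonical(self) -> 'DemoConfig':
        ret = DemoConfig(self.code_dir)
        for key, value in vars(ret).items():
            if isinstance(value, Path):
                setattr(ret, key, str(value))  # flatten Path fields to strings
        return ret

config = DemoConfig(code_dir=Path('~/my-trial').expanduser())
assert isinstance(config.canonical().code_dir, str)  # now JSON-serializable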
30 changes: 22 additions & 8 deletions nni/experiment/config/common.py
@@ -5,6 +5,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from ruamel.yaml import YAML
+
 from .base import ConfigBase, PathLike
 from . import util
@@ -27,23 +29,27 @@ def validate(self):
         super().validate()
         _validate_algo(self)
 
 
 @dataclass(init=False)
 class AlgorithmConfig(_AlgorithmConfig):
     name: str
     class_args: Optional[Dict[str, Any]] = None
 
 
 @dataclass(init=False)
 class CustomAlgorithmConfig(_AlgorithmConfig):
     class_name: str
-    class_directory: Optional[PathLike] = None
+    class_directory: Optional[PathLike] = '.'
     class_args: Optional[Dict[str, Any]] = None
 
 
 class TrainingServiceConfig(ConfigBase):
     platform: str
 
+class SharedStorageConfig(ConfigBase):
+    storage_type: str
+    local_mount_point: str
+    remote_mount_point: str
+    local_mounted: str
 
 
 @dataclass(init=False)
 class ExperimentConfig(ConfigBase):
@@ -53,19 +59,21 @@ class ExperimentConfig(ConfigBase):
     trial_command: str
     trial_code_directory: PathLike = '.'
     trial_concurrency: int
-    trial_gpu_number: Optional[int] = None
+    trial_gpu_number: Optional[int] = None  # TODO: in openpai cannot be None
     max_experiment_duration: Optional[str] = None
     max_trial_number: Optional[int] = None
     nni_manager_ip: Optional[str] = None
     use_annotation: bool = False
     debug: bool = False
     log_level: Optional[str] = None
-    experiment_working_directory: Optional[PathLike] = None
+    experiment_working_directory: PathLike = '~/nni-experiments'
     tuner_gpu_indices: Optional[Union[List[int], str]] = None
     tuner: Optional[_AlgorithmConfig] = None
     assessor: Optional[_AlgorithmConfig] = None
     advisor: Optional[_AlgorithmConfig] = None
     training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
+    shared_storage: Optional[SharedStorageConfig] = None
+    _deprecated: Optional[Dict[str, Any]] = None
 
     def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
         base_path = kwargs.pop('_base_path', None)
@@ -100,6 +108,12 @@ def validate(self, initialized_tuner: bool = False) -> None:
             if self.training_service.use_active_gpu is None:
                 raise ValueError('Please set "use_active_gpu"')
 
+    def json(self) -> Dict[str, Any]:
+        obj = super().json()
+        if obj.get('searchSpaceFile'):
+            obj['searchSpace'] = YAML().load(open(obj.pop('searchSpaceFile')))
+        return obj
+
     ## End of public API ##
 
     @property
@@ -117,9 +131,9 @@ def _validation_rules(self):
             'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
             'experiment_working_directory': util.canonical_path,
             'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value,
-            'tuner': lambda config: None if config is None or config.name == '_none_' else config,
-            'assessor': lambda config: None if config is None or config.name == '_none_' else config,
-            'advisor': lambda config: None if config is None or config.name == '_none_' else config,
+            'tuner': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
+            'assessor': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
+            'advisor': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
         }
 
     _validation_rules = {
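Taken together, the common.py changes add a SharedStorageConfig section, give experiment_working_directory a concrete default, normalize tuner/assessor/advisor configs through canonical(), and let json() inline a YAML search-space file via ruamel.yaml. A hedged usage sketch, assuming ExperimentConfig and AlgorithmConfig are re-exported from nni.experiment.config and using illustrative values:

# Hedged usage sketch; field names come from the dataclasses in this diff,
# while the import path and values are assumptions for illustration.
from nni.experiment.config import AlgorithmConfig, ExperimentConfig

config = ExperimentConfig(training_service_platform='local')
config.trial_command = 'python3 trial.py'
config.trial_code_directory = '.'       # PathLike, default '.'
config.trial_concurrency = 2
config.tuner = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})
# experiment_working_directory now defaults to '~/nni-experiments';
# shared_storage remains None unless a SharedStorageConfig is supplied.

# json() emits the canonical dict; when it contains 'searchSpaceFile',
# the file is loaded with ruamel.yaml and moved under 'searchSpace'.
payload = config.json()

Calling canonical() inside the tuner/assessor/advisor rules means nested algorithm configs are normalized (for example, their Path fields stringified) before the experiment payload is serialized.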
(The remaining 65 changed files are not shown here.)
