Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add idompotent support for get_parameters() in nni sdk (#216)
Browse files Browse the repository at this point in the history
* Updated based on comments

* Fix bug, make get_parameters() idompotent

* Add idompotent support for get_parameters() in LocalTrainingService
  • Loading branch information
yds05 authored Oct 16, 2018
1 parent 14fac16 commit 9bb479b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 6 deletions.
10 changes: 10 additions & 0 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ class NNIManager implements Manager {
await this.storeExperimentProfile();
this.log.debug('Setup tuner...');

// Set up multiphase config
if(expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
Expand All @@ -140,6 +145,11 @@ class NNIManager implements Manager {
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params;

// Set up multiphase config
if(expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export enum TrialConfigMetadataKey {
MACHINE_LIST = 'machine_list',
TRIAL_CONFIG = 'trial_config',
EXPERIMENT_ID = 'experimentId',
MULTI_PHASE = 'multiPhase',
RANDOM_SCHEDULER = 'random_scheduler',
PAI_CLUSTER_CONFIG = 'pai_config'
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class LocalTrainingService implements TrainingService {
private trialSequenceId: number;
protected log: Logger;
protected localTrailConfig?: TrialConfig;
private isMultiPhase: boolean = false;

constructor() {
this.eventEmitter = new EventEmitter();
Expand Down Expand Up @@ -237,7 +238,7 @@ class LocalTrainingService implements TrainingService {
* Is multiphase job supported in current training service
*/
public get isMultiPhaseJobSupported(): boolean {
return false;
return true;
}

public async cancelTrialJob(trialJobId: string): Promise<void> {
Expand Down Expand Up @@ -270,6 +271,9 @@ class LocalTrainingService implements TrainingService {
throw new Error('trial config parsed failed');
}
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
}
}
Expand Down Expand Up @@ -304,7 +308,8 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_PLATFORM', value: 'local' },
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory }
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

'use strict';

import { Client } from 'ssh2';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary } from '../common/gpuData';

Expand Down Expand Up @@ -112,6 +111,7 @@ export enum ScheduleResultType {
export const REMOTEMACHINE_RUN_SHELL_FORMAT: string =
`#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_TRIAL_JOB_ID={1} NNI_OUTPUT_DIR={0}
export MULTI_PHASE={7}
cd $NNI_SYS_DIR
echo $$ >{2}
eval {3}{4} 2>{5}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class RemoteMachineTrainingService implements TrainingService {
private stopping: boolean = false;
private metricsEmitter: EventEmitter;
private log: Logger;
private isMultiPhase: boolean = false;
private trialSequenceId: number;

constructor(@component.Inject timer: ObservableTimer) {
Expand Down Expand Up @@ -226,7 +227,7 @@ class RemoteMachineTrainingService implements TrainingService {
* Is multiphase job supported in current training service
*/
public get isMultiPhaseJobSupported(): boolean {
return false;
return true;
}

/**
Expand Down Expand Up @@ -295,6 +296,9 @@ class RemoteMachineTrainingService implements TrainingService {
}
this.trialConfig = remoteMachineTrailConfig;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Uknown key: ${key}`);
Expand Down Expand Up @@ -457,7 +461,9 @@ class RemoteMachineTrainingService implements TrainingService {
`CUDA_VISIBLE_DEVICES=${cuda_visible_device} ` : `CUDA_VISIBLE_DEVICES=" " `,
this.trialConfig.command,
path.join(trialWorkingFolder, 'stderr'),
path.join(trialWorkingFolder, '.nni', 'code'));
path.join(trialWorkingFolder, '.nni', 'code'),
/** Mark if the trial is multi-phase job */
this.isMultiPhase);

//create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${path.join(trialLocalTempFolder, '.nni')}`);
Expand Down
10 changes: 9 additions & 1 deletion src/sdk/pynni/nni/platform/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)

_multiphase = os.environ.get('MULTI_PHASE')

_param_index = 0

def request_next_parameter():
Expand All @@ -49,7 +51,13 @@ def request_next_parameter():

def get_parameters():
global _param_index
params_filepath = os.path.join(_sysdir, ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0])
params_file_name = ''
if _multiphase and (_multiphase == 'true' or _multiphase == 'True'):
params_file_name = ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0]
else:
params_file_name = 'parameter.cfg'

params_filepath = os.path.join(_sysdir, params_file_name)
if not os.path.isfile(params_filepath):
request_next_parameter()
while not os.path.isfile(params_filepath):
Expand Down

0 comments on commit 9bb479b

Please sign in to comment.