Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev enas - multi-phase hyper parameters support #96

Merged
merged 7 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ before_install:
- sudo sh -c 'PATH=/usr/local/node/bin:$PATH yarn global add serve'
install:
- make
- make install
- make dev-install
- export PATH=$HOME/.nni/bin:$PATH
before_script:
- cd test/naive
Expand Down
4 changes: 2 additions & 2 deletions src/nni_manager/common/datastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import { ExperimentProfile, TrialJobStatistics } from './manager';
import { TrialJobDetail, TrialJobStatus } from './trainingService';

type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED';
/**
 * Event recorded for a trial job: either a TrialJobStatus value or one of the extra event names below.
 * 'ADD_HYPERPARAMETER' is the event stored when an additional hyper parameter set is sent to an existing trial job.
 */
type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM';

interface ExperimentProfileRecord {
Expand Down Expand Up @@ -62,7 +62,7 @@ interface TrialJobInfo {
status: TrialJobStatus;
startTime?: number;
endTime?: number;
hyperParameters?: string;
hyperParameters?: string[];
logPath?: string;
finalMetricData?: string;
stderrPath?: string;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ interface ExperimentParams {
maxExecDuration: number; //seconds
maxTrialNum: number;
searchSpace: string;
multiPhase?: boolean;
tuner: {
className: string;
builtinTunerName?: string;
Expand Down
9 changes: 7 additions & 2 deletions src/nni_manager/common/trainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ interface JobApplicationForm {
readonly jobType: JobType;
}

/**
 * One set of hyper parameters handed to a trial job.
 */
interface HyperParameters {
// Parameter payload; the training services write this string verbatim into the trial's parameter file.
readonly value: string;
// Sequence number of this parameter set; used to build the file name `parameter_<index>.cfg`.
readonly index: number;
}

/**
* define TrialJobApplicationForm
*/
interface TrialJobApplicationForm extends JobApplicationForm {
readonly hyperParameters: string;
readonly hyperParameters: HyperParameters;
}

/**
Expand Down Expand Up @@ -116,6 +121,6 @@ abstract class TrainingService {

export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType
};
5 changes: 4 additions & 1 deletion src/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ function parseArg(names: string[]): string {
* @param assessor: similar to tuner
*
*/
function getMsgDispatcherCommand(tuner: any, assessor: any): string {
function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
if (multiPhase) {
command += ' --multi_phase';
}

if (process.env.VIRTUAL_ENV) {
command = path.join(process.env.VIRTUAL_ENV, 'bin/') +command;
Expand Down
5 changes: 4 additions & 1 deletion src/nni_manager/core/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const TRIAL_END = 'EN';
const TERMINATE = 'TE';

const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const NO_MORE_TRIAL_JOBS = 'NO';
const KILL_TRIAL_JOB = 'KI';

Expand All @@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
TERMINATE,

NEW_TRIAL_JOB,
SEND_TRIAL_JOB_PARAMETER,
NO_MORE_TRIAL_JOBS
]);

Expand All @@ -63,5 +65,6 @@ export {
NO_MORE_TRIAL_JOBS,
KILL_TRIAL_JOB,
TUNER_COMMANDS,
ASSESSOR_COMMANDS
ASSESSOR_COMMANDS,
SEND_TRIAL_JOB_PARAMETER
};
30 changes: 26 additions & 4 deletions src/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class NNIDataStore implements DataStore {
}

public async storeMetricData(trialJobId: string, data: string): Promise<void> {
this.log.debug(`storeMetricData: trialJobId: ${trialJobId}, data: ${data}`);
const metrics = JSON.parse(data) as MetricData;
assert(trialJobId === metrics.trial_job_id);
await this.db.storeMetricData(trialJobId, JSON.stringify({
Expand Down Expand Up @@ -168,18 +169,34 @@ class NNIDataStore implements DataStore {
}
}

private getJobStatusByLatestEvent(event: TrialJobEvent): TrialJobStatus {
/**
 * Translate the most recent event of a trial job into the job status it implies.
 *
 * @param oldStatus status the job had before this event
 * @param event latest event recorded for the job
 * @returns 'USER_CANCELED' for 'USER_TO_CANCEL', 'WAITING' for 'ADD_CUSTOMIZED',
 *          the unchanged oldStatus for 'ADD_HYPERPARAMETER', otherwise the event itself as a status
 */
private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus {
    if (event === 'USER_TO_CANCEL') {
        return 'USER_CANCELED';
    }
    if (event === 'ADD_CUSTOMIZED') {
        return 'WAITING';
    }
    if (event === 'ADD_HYPERPARAMETER') {
        // Receiving an extra parameter set does not move the job to another status.
        return oldStatus;
    }

    // Every remaining event name is itself a status value.
    return <TrialJobStatus>event;
}

/**
 * Combine the hyper parameters collected so far with a newly recorded one.
 *
 * Every entry is a JSON string. The new entry is appended only when no existing entry has the
 * same `parameter_index`; otherwise it is dropped and the existing entries are kept.
 * All entries are parsed and serialized again, so the result is in JSON.stringify form.
 *
 * @param hyperParamList JSON strings of the parameters already known
 * @param newParamStr JSON string of the parameter to merge in
 * @returns a new array of JSON strings; the input array is left untouched
 */
private mergeHyperParameters(hyperParamList: string[], newParamStr: string): string[] {
    // The incoming entry is parsed before the existing ones, so a malformed new entry fails first.
    const incoming: any = JSON.parse(newParamStr);
    const known: any[] = hyperParamList.map((raw: string) => JSON.parse(raw));

    const sameIndex: any[] = known.filter((entry: any) => entry.parameter_index === incoming.parameter_index);
    if (sameIndex.length === 0) {
        known.push(incoming);
    }

    return known.map((entry: any): string => JSON.stringify(entry));
}

private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
const map: Map<string, TrialJobInfo> = new Map();
// assume data is stored by time ASC order
Expand All @@ -193,7 +210,8 @@ class NNIDataStore implements DataStore {
} else {
jobInfo = {
id: record.trialJobId,
status: this.getJobStatusByLatestEvent(record.event)
status: this.getJobStatusByLatestEvent('UNKNOWN', record.event),
hyperParameters: []
};
}
if (!jobInfo) {
Expand Down Expand Up @@ -222,9 +240,13 @@ class NNIDataStore implements DataStore {
}
default:
}
jobInfo.status = this.getJobStatusByLatestEvent(record.event);
jobInfo.status = this.getJobStatusByLatestEvent(jobInfo.status, record.event);
if (record.data !== undefined && record.data.trim().length > 0) {
jobInfo.hyperParameters = record.data;
if (jobInfo.hyperParameters !== undefined) {
jobInfo.hyperParameters = this.mergeHyperParameters(jobInfo.hyperParameters, record.data);
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
}
}
map.set(record.trialJobId, jobInfo);
}
Expand Down
27 changes: 23 additions & 4 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
import { delay , getLogDir, getMsgDispatcherCommand} from '../common/utils';
import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
} from './commands';
import { createDispatcherInterface, IpcInterface } from './ipcInterface';
import { TrialJobMaintainerEvent, TrialJobs } from './trialJobs';
Expand Down Expand Up @@ -116,7 +116,7 @@ class NNIManager implements Manager {
await this.storeExperimentProfile();
this.log.debug('Setup tuner...');

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
//expParams.tuner.tunerCommand,
Expand All @@ -140,7 +140,7 @@ class NNIManager implements Manager {
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params;

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
dispatcherCommand,
Expand Down Expand Up @@ -460,7 +460,10 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: content
hyperParameters: {
value: content,
index: 0
}
};
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
this.trialJobsMaintainer.setTrialJob(trialJobDetail.id, Object.assign({}, trialJobDetail));
Expand All @@ -472,6 +475,22 @@ class NNIManager implements Manager {
}
}
break;
case SEND_TRIAL_JOB_PARAMETER:
const tunerCommand: any = JSON.parse(content);
assert(tunerCommand.parameter_index >= 0);
assert(tunerCommand.trial_job_id !== undefined);

const trialJobForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: content,
index: tunerCommand.parameter_index
}
};
await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
await this.dataStore.storeTrialJobEvent(
'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
break;
case NO_MORE_TRIAL_JOBS:
this.trialJobsMaintainer.setNoMoreTrials();
break;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export namespace ValidationSchemas {
trialConcurrency: joi.number().min(0).required(),
searchSpace: joi.string().required(),
maxExecDuration: joi.number().min(0).required(),
multiPhase: joi.boolean(),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution'),
codeDir: joi.string(),
Expand Down
27 changes: 20 additions & 7 deletions src/nni_manager/training_service/local/localTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ import { getLogger, Logger } from '../../common/log';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import {
HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { file } from 'tmp';

const tkill = require('tree-kill');

Expand Down Expand Up @@ -210,8 +211,18 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
    // Only jobs already present in jobMap can be updated.
    const jobDetail: TrialJobDetail | undefined = this.jobMap.get(trialJobId);
    if (jobDetail === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
    }

    // Any form other than a trial form is rejected before a file is written.
    if (form.jobType !== 'TRIAL') {
        throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
    }

    // Store the new hyper parameters in the working directory of the trial.
    const trialForm: TrialJobApplicationForm = <TrialJobApplicationForm>form;
    await this.writeParameterFile(jobDetail.workingDirectory, trialForm.hyperParameters);

    return jobDetail;
}

/**
Expand Down Expand Up @@ -332,10 +343,7 @@ class LocalTrainingService implements TrainingService {
await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`);
await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`);
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8' });
await fs.promises.writeFile(
path.join(trialJobDetail.workingDirectory, 'parameter.cfg'),
(<TrialJobApplicationForm>trialJobDetail.form).hyperParameters,
{ encoding: 'utf8' });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`);

this.setTrialJobStatus(trialJobDetail, 'RUNNING');
Expand Down Expand Up @@ -402,6 +410,11 @@ class LocalTrainingService implements TrainingService {
}
}
}

/**
 * Write one hyper parameter set into a directory as `parameter_<index>.cfg` (UTF-8).
 *
 * @param directory directory that receives the file
 * @param hyperParameters parameter payload (`value`) and its sequence number (`index`)
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
    const { index, value } = hyperParameters;
    const target: string = path.join(directory, `parameter_${index}.cfg`);

    await fs.promises.writeFile(target, value, { encoding: 'utf8' });
}
}

export { LocalTrainingService };
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
import {
HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { GPUSummary } from '../common/gpuData';
Expand Down Expand Up @@ -198,8 +198,24 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
    this.log.info(`updateTrialJob: form: ${JSON.stringify(form)}`);

    // Only jobs already present in trialJobsMap can be updated.
    const jobDetail: TrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
    if (jobDetail === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
    }

    // The job type is checked before the machine metadata, so an unsupported form is always reported as such.
    if (form.jobType !== 'TRIAL') {
        throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
    }

    // The parameter file is copied to the machine recorded on the job detail.
    const machine: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>jobDetail).rmMeta;
    if (machine === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
    }

    await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, machine);

    return jobDetail;
}

/**
Expand Down Expand Up @@ -442,15 +458,13 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);

// Write file content ( run.sh and parameter.cfg ) to local tmp files
// Write file content ( run.sh and parameter_0.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'parameter.cfg'), form.hyperParameters, { encoding: 'utf8' });

// Copy local tmp files to remote machine
await SSHClientUtility.copyFileToRemote(
path.join(trialLocalTempFolder, 'run.sh'), path.join(trialWorkingFolder, 'run.sh'), sshClient);
await SSHClientUtility.copyFileToRemote(
path.join(trialLocalTempFolder, 'parameter.cfg'), path.join(trialWorkingFolder, 'parameter.cfg'), sshClient);
await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);

// Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(this.trialConfig.codeDir, trialWorkingFolder, sshClient);
Expand Down Expand Up @@ -562,6 +576,22 @@ class RemoteMachineTrainingService implements TrainingService {

return jobpidPath;
}

/**
 * Write one hyper parameter set to the local temp folder of a trial and copy it to the remote working folder.
 *
 * The file is named `parameter_<index>.cfg` on both sides.
 *
 * @param trialJobId id of the trial job; selects the per-trial folders
 * @param hyperParameters parameter payload (`value`) and its sequence number (`index`)
 * @param rmMeta remote machine whose SSH client is used for the copy
 */
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
    // Fail before touching the file system when no SSH client is registered for the machine.
    const sshClient: Client | undefined = this.machineSSHClientMap.get(rmMeta);
    if (sshClient === undefined) {
        throw new Error('sshClient is undefined.');
    }

    const fileName: string = `parameter_${hyperParameters.index}.cfg`;
    const localCopy: string = path.join(this.expRootDir, 'trials-local', trialJobId, fileName);
    const remoteCopy: string = path.join(this.remoteExpRootDir, 'trials', trialJobId, fileName);

    // Write the file locally first, then send that local file to the remote machine.
    await fs.promises.writeFile(localCopy, hyperParameters.value, { encoding: 'utf8' });
    await SSHClientUtility.copyFileToRemote(localCopy, remoteCopy, sshClient);
}
}

export { RemoteMachineTrainingService };
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: 'mock hyperparameters'
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const trialJob = await remoteMachineTrainingService.submitTrialJob(form);

Expand Down Expand Up @@ -135,7 +138,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: 'mock hyperparameters'
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await remoteMachineTrainingService.submitTrialJob(form);
// Add metrics listeners
Expand Down
Loading