Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Multi-phase training service #148

Merged
merged 14 commits into from
Oct 8, 2018
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ before_install:
- sudo sh -c 'PATH=/usr/local/node/bin:$PATH yarn global add serve'
install:
- make
- make install
- make easy-install
- export PATH=$HOME/.nni/bin:$PATH
before_script:
- cd test/naive
Expand Down
6 changes: 3 additions & 3 deletions src/nni_manager/common/datastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import { ExperimentProfile, TrialJobStatistics } from './manager';
import { TrialJobDetail, TrialJobStatus } from './trainingService';

type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM';
type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM' | 'REQUEST_PARAMETER';

interface ExperimentProfileRecord {
readonly timestamp: number;
Expand Down Expand Up @@ -62,7 +62,7 @@ interface TrialJobInfo {
status: TrialJobStatus;
startTime?: number;
endTime?: number;
hyperParameters?: string;
hyperParameters?: string[];
logPath?: string;
finalMetricData?: string;
stderrPath?: string;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ interface ExperimentParams {
maxExecDuration: number; //seconds
maxTrialNum: number;
searchSpace: string;
multiPhase?: boolean;
tuner: {
className: string;
builtinTunerName?: string;
Expand Down
9 changes: 7 additions & 2 deletions src/nni_manager/common/trainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ interface JobApplicationForm {
readonly jobType: JobType;
}

/**
 * One set of hyper-parameters handed to a trial job.
 */
interface HyperParameters {
    /** Sequence number of this parameter set within its trial job (the initial set uses 0). */
    readonly index: number;
    /** Serialized parameter payload; it is passed through as an opaque string. */
    readonly value: string;
}

/**
* define TrialJobApplicationForm
*/
interface TrialJobApplicationForm extends JobApplicationForm {
readonly hyperParameters: string;
readonly hyperParameters: HyperParameters;
}

/**
Expand Down Expand Up @@ -116,6 +121,6 @@ abstract class TrainingService {

export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType
};
5 changes: 4 additions & 1 deletion src/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ function parseArg(names: string[]): string {
* @param assessor: similar to tuner
*
*/
function getMsgDispatcherCommand(tuner: any, assessor: any): string {
function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
if (multiPhase) {
command += ' --multi_phase';
}

if (tuner.classArgs !== undefined) {
command += ` --tuner_args ${JSON.stringify(JSON.stringify(tuner.classArgs))}`;
Expand Down
5 changes: 4 additions & 1 deletion src/nni_manager/core/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const TRIAL_END = 'EN';
const TERMINATE = 'TE';

const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const NO_MORE_TRIAL_JOBS = 'NO';
const KILL_TRIAL_JOB = 'KI';

Expand All @@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
TERMINATE,

NEW_TRIAL_JOB,
SEND_TRIAL_JOB_PARAMETER,
NO_MORE_TRIAL_JOBS
]);

Expand All @@ -63,5 +65,6 @@ export {
NO_MORE_TRIAL_JOBS,
KILL_TRIAL_JOB,
TUNER_COMMANDS,
ASSESSOR_COMMANDS
ASSESSOR_COMMANDS,
SEND_TRIAL_JOB_PARAMETER
};
62 changes: 53 additions & 9 deletions src/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import * as component from '../common/component';
import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo } from '../common/datastore';
import { isNewExperiment } from '../common/experimentStartupInfo';
import { getExperimentId } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, TrialJobStatistics } from '../common/manager';
import { TrialJobStatus } from '../common/trainingService';
Expand All @@ -35,6 +36,7 @@ class NNIDataStore implements DataStore {
private db: Database = component.get(Database);
private log: Logger = getLogger();
private initTask!: Deferred<void>;
private multiPhase: boolean | undefined;

public init(): Promise<void> {
if (this.initTask !== undefined) {
Expand Down Expand Up @@ -112,13 +114,19 @@ class NNIDataStore implements DataStore {
}

public async getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
const trialJobs = await this.queryTrialJobs(undefined, trialJobId);
const trialJobs: TrialJobInfo[] = await this.queryTrialJobs(undefined, trialJobId);

return trialJobs[0];
}

public async storeMetricData(trialJobId: string, data: string): Promise<void> {
const metrics = JSON.parse(data) as MetricData;
const metrics: MetricData = JSON.parse(data);
// REQUEST_PARAMETER is used to request new parameters for a multi-phase trial job;
// it is not a metric, so it is skipped here.
if (metrics.type === 'REQUEST_PARAMETER') {
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved

return;
}
assert(trialJobId === metrics.trial_job_id);
await this.db.storeMetricData(trialJobId, JSON.stringify({
trialJobId: metrics.trial_job_id,
Expand Down Expand Up @@ -160,25 +168,56 @@ class NNIDataStore implements DataStore {

private async getFinalMetricData(trialJobId: string): Promise<any> {
const metrics: MetricDataRecord[] = await this.getMetricData(trialJobId, 'FINAL');
if (metrics.length > 1) {
this.log.error(`Found multiple final results for trial job: ${trialJobId}`);

const multiPhase: boolean = await this.isMultiPhase();

if (metrics.length > 1 && !multiPhase) {
this.log.error(`Found multiple FINAL results for trial job ${trialJobId}`);
}

return metrics[metrics.length - 1];
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
}

private async isMultiPhase(): Promise<boolean> {
if (this.multiPhase === undefined) {
this.multiPhase = (await this.getExperimentProfile(getExperimentId())).params.multiPhase;
}

return metrics[0];
if (this.multiPhase !== undefined) {
return this.multiPhase;
} else {
return false;
}
}

private getJobStatusByLatestEvent(event: TrialJobEvent): TrialJobStatus {
private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus {
switch (event) {
case 'USER_TO_CANCEL':
return 'USER_CANCELED';
case 'ADD_CUSTOMIZED':
return 'WAITING';
case 'ADD_HYPERPARAMETER':
return oldStatus;
default:
}

return <TrialJobStatus>event;
}

/**
 * Merge a newly received hyper-parameter record into the list already known for a trial job.
 *
 * Every entry is a JSON string. The new entry is appended only when no existing
 * entry carries the same `parameter_index`; otherwise the list content is kept as is.
 * All entries are re-serialized with JSON.stringify before being returned.
 *
 * @param hyperParamList JSON strings of the hyper-parameters collected so far
 * @param newParamStr JSON string of the hyper-parameter record to merge in
 * @returns the merged list, each element serialized as a JSON string
 */
private mergeHyperParameters(hyperParamList: string[], newParamStr: string): string[] {
    // The incoming record is parsed first, so a malformed new record fails before the list is read.
    const incoming: any = JSON.parse(newParamStr);
    const merged: any[] = hyperParamList.map((serialized: string) => JSON.parse(serialized));

    // Scan every entry (no early exit) and count those sharing the incoming index.
    let sameIndexCount: number = 0;
    for (const existing of merged) {
        if (existing.parameter_index === incoming.parameter_index) {
            sameIndexCount += 1;
        }
    }
    if (sameIndexCount === 0) {
        merged.push(incoming);
    }

    const result: string[] = [];
    for (const entry of merged) {
        result.push(JSON.stringify(entry));
    }

    return result;
}

private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
const map: Map<string, TrialJobInfo> = new Map();
// assume data is stored in ascending time order
Expand All @@ -192,7 +231,8 @@ class NNIDataStore implements DataStore {
} else {
jobInfo = {
id: record.trialJobId,
status: this.getJobStatusByLatestEvent(record.event)
status: this.getJobStatusByLatestEvent('UNKNOWN', record.event),
hyperParameters: []
};
}
if (!jobInfo) {
Expand Down Expand Up @@ -221,9 +261,13 @@ class NNIDataStore implements DataStore {
}
default:
}
jobInfo.status = this.getJobStatusByLatestEvent(record.event);
jobInfo.status = this.getJobStatusByLatestEvent(jobInfo.status, record.event);
if (record.data !== undefined && record.data.trim().length > 0) {
jobInfo.hyperParameters = record.data;
if (jobInfo.hyperParameters !== undefined) {
jobInfo.hyperParameters = this.mergeHyperParameters(jobInfo.hyperParameters, record.data);
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
}
}
map.set(record.trialJobId, jobInfo);
}
Expand Down
27 changes: 23 additions & 4 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
import { delay , getLogDir, getMsgDispatcherCommand} from '../common/utils';
import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
} from './commands';
import { createDispatcherInterface, IpcInterface } from './ipcInterface';
import { TrialJobMaintainerEvent, TrialJobs } from './trialJobs';
Expand Down Expand Up @@ -116,7 +116,7 @@ class NNIManager implements Manager {
await this.storeExperimentProfile();
this.log.debug('Setup tuner...');

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
//expParams.tuner.tunerCommand,
Expand All @@ -140,7 +140,7 @@ class NNIManager implements Manager {
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params;

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
dispatcherCommand,
Expand Down Expand Up @@ -462,7 +462,10 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: content
hyperParameters: {
value: content,
index: 0
}
};
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
this.trialJobsMaintainer.setTrialJob(trialJobDetail.id, Object.assign({}, trialJobDetail));
Expand All @@ -474,6 +477,22 @@ class NNIManager implements Manager {
}
}
break;
case SEND_TRIAL_JOB_PARAMETER:
const tunerCommand: any = JSON.parse(content);
assert(tunerCommand.parameter_index >= 0);
assert(tunerCommand.trial_job_id !== undefined);

const trialJobForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: content,
index: tunerCommand.parameter_index
}
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
};
await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
await this.dataStore.storeTrialJobEvent(
'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
break;
case NO_MORE_TRIAL_JOBS:
this.trialJobsMaintainer.setNoMoreTrials();
break;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export namespace ValidationSchemas {
trialConcurrency: joi.number().min(0).required(),
searchSpace: joi.string().required(),
maxExecDuration: joi.number().min(0).required(),
multiPhase: joi.boolean(),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner'),
codeDir: joi.string(),
Expand Down
27 changes: 20 additions & 7 deletions src/nni_manager/training_service/local/localTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ import { getLogger, Logger } from '../../common/log';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import {
HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { file } from 'tmp';

const tkill = require('tree-kill');

Expand Down Expand Up @@ -210,8 +211,18 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}

return trialJobDetail;
}

/**
Expand Down Expand Up @@ -332,10 +343,7 @@ class LocalTrainingService implements TrainingService {
await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`);
await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`);
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8' });
await fs.promises.writeFile(
path.join(trialJobDetail.workingDirectory, 'parameter.cfg'),
(<TrialJobApplicationForm>trialJobDetail.form).hyperParameters,
{ encoding: 'utf8' });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`);

this.setTrialJobStatus(trialJobDetail, 'RUNNING');
Expand Down Expand Up @@ -402,6 +410,11 @@ class LocalTrainingService implements TrainingService {
}
}
}

/**
 * Write one set of hyper-parameters into a trial's directory.
 *
 * The target file is `parameter_<index>.cfg` inside `directory`, so each
 * parameter index gets its own file; the payload is written as UTF-8 text.
 *
 * @param directory directory in which the parameter file is created
 * @param hyperParameters parameter payload and its index
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
    const { index, value } = hyperParameters;
    const fileName: string = `parameter_${index}.cfg`;
    await fs.promises.writeFile(path.join(directory, fileName), value, { encoding: 'utf8' });
}
}

export { LocalTrainingService };
Loading