Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix localTrainingService stream (#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
SparkSnail authored Mar 26, 2019
1 parent 892c9c4 commit bd34681
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/nni_manager/training_service/local/localTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class LocalTrainingService implements TrainingService {
protected log: Logger;
protected localTrailConfig?: TrialConfig;
private isMultiPhase: boolean = false;
private streams: Array<ts.Stream>;
protected jobStreamMap: Map<string, ts.Stream>;

constructor() {
this.eventEmitter = new EventEmitter();
Expand All @@ -113,7 +113,7 @@ class LocalTrainingService implements TrainingService {
this.stopping = false;
this.log = getLogger();
this.trialSequenceId = -1;
this.streams = new Array<ts.Stream>();
this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.');
}

Expand Down Expand Up @@ -307,14 +307,24 @@ class LocalTrainingService implements TrainingService {
/**
 * Stops the local training service.
 *
 * Logs the shutdown, sets the `stopping` flag, and destroys every stream
 * currently registered in `jobStreamMap`.
 *
 * @returns an already-resolved promise; all work is done synchronously.
 */
public cleanUp(): Promise<void> {
    this.log.info('Stopping local machine training service...');
    this.stopping = true;
    for (const stream of this.jobStreamMap.values()) {
        stream.destroy();
    }
    // Forget the destroyed streams so a later status-change callback
    // does not look them up and destroy them a second time.
    this.jobStreamMap.clear();
    return Promise.resolve();
}

/**
 * Hook invoked when a trial job's status changes.
 *
 * If the job's current status is one of the finished states listed in the
 * body, the stream registered for the job in `jobStreamMap` is destroyed and
 * its entry removed. Jobs with no registered stream are left untouched.
 *
 * @param trialJob  the trial job whose status changed; its `status` and `id` are read.
 * @param oldStatus the job's previous status; unused by this implementation.
 */
protected onTrialJobStatusChanged(trialJob: TrialJobDetail, oldStatus: TrialJobStatus): void {
    // Only a job that is no longer running needs its stream torn down.
    if (!['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'].includes(trialJob.status)) {
        return;
    }
    // A single lookup is enough: `get` yields undefined when no stream was registered.
    const stream = this.jobStreamMap.get(trialJob.id);
    if (stream === undefined) {
        return;
    }
    stream.destroy();
    this.jobStreamMap.delete(trialJob.id);
}

protected getEnvironmentVariables(trialJobDetail: TrialJobDetail, _: {}): { key: string; value: string }[] {
Expand Down Expand Up @@ -396,7 +406,8 @@ class LocalTrainingService implements TrainingService {
buffer = remain;
}
});
this.streams.push(stream);

this.jobStreamMap.set(trialJobDetail.id, stream);
}

private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class LocalTrainingServiceForGPU extends LocalTrainingService {
}

protected onTrialJobStatusChanged(trialJob: LocalTrialJobDetailForGPU, oldStatus: TrialJobStatus): void {
super.onTrialJobStatusChanged(trialJob, oldStatus);
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length !== 0 && this.gpuScheduler !== undefined) {
if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
for (const index of trialJob.gpuIndices) {
Expand Down

0 comments on commit bd34681

Please sign in to comment.