Skip to content

Commit

Permalink
Merge pull request microsoft#148 from Microsoft/master
Browse files Browse the repository at this point in the history
merge master
  • Loading branch information
SparkSnail authored Mar 25, 2019
2 parents e1ae623 + 8fd18a5 commit ec41d56
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
48 changes: 29 additions & 19 deletions src/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -250,29 +250,26 @@ class NNIDataStore implements DataStore {
return <TrialJobStatus>event;
}

private mergeHyperParameters(hyperParamList: string[], newParamStr: string): string[] {
const mergedHyperParams: any[] = [];
let newParam: any;
private parseHyperParameter(hParamStr: string): any {
let hParam: any;
try {
newParam = JSON.parse(newParamStr);
hParam = JSON.parse(hParamStr);

return hParam;
} catch (err) {
this.log.error(`Hyper parameter needs to be in json format: ${newParamStr}`);
this.log.error(`Hyper parameter needs to be in json format: ${hParamStr}`);

return hyperParamList;
}
for (const hyperParamStr of hyperParamList) {
const hyperParam: any = JSON.parse(hyperParamStr);
mergedHyperParams.push(hyperParam);
}
if (mergedHyperParams.filter((value: any) => value.parameter_index === newParam.parameter_index).length <= 0) {
mergedHyperParams.push(newParam);
return undefined;
}

return mergedHyperParams.map<string>((value: any) => { return JSON.stringify(value); });
}

// tslint:disable-next-line:cyclomatic-complexity
private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
this.log.debug('getTrialJobsByReplayEvents begin');

const map: Map<string, TrialJobInfo> = new Map();
const hParamIdMap: Map<string, Set<number>> = new Map();

// assume data is stored by time ASC order
for (const record of trialJobEvents) {
let jobInfo: TrialJobInfo | undefined;
Expand Down Expand Up @@ -322,10 +319,21 @@ class NNIDataStore implements DataStore {
}
jobInfo.status = this.getJobStatusByLatestEvent(jobInfo.status, record.event);
if (record.data !== undefined && record.data.trim().length > 0) {
if (jobInfo.hyperParameters !== undefined) {
jobInfo.hyperParameters = this.mergeHyperParameters(jobInfo.hyperParameters, record.data);
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
const newHParam: any = this.parseHyperParameter(record.data);
if (newHParam !== undefined) {
if (jobInfo.hyperParameters !== undefined) {
let hParamIds: Set<number> | undefined = hParamIdMap.get(jobInfo.id);
if (hParamIds === undefined) {
hParamIds = new Set();
}
if (!hParamIds.has(newHParam.parameter_index)) {
jobInfo.hyperParameters.push(JSON.stringify(newHParam));
hParamIds.add(newHParam.parameter_index);
hParamIdMap.set(jobInfo.id, hParamIds);
}
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
}
}
}
if (record.sequenceId !== undefined && jobInfo.sequenceId === undefined) {
Expand All @@ -334,6 +342,8 @@ class NNIDataStore implements DataStore {
map.set(record.trialJobId, jobInfo);
}

this.log.debug('getTrialJobsByReplayEvents done');

return map;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class Hyperband(MsgDispatcherBase):
"""
def __init__(self, R, eta=3, optimize_mode='maximize'):
"""B = (s_max + 1)R"""
super()
super(Hyperband, self).__init__()
self.R = R # pylint: disable=invalid-name
self.eta = eta
self.brackets = dict() # dict of Bracket
Expand Down

0 comments on commit ec41d56

Please sign in to comment.