diff --git a/deployment/docker/Dockerfile b/deployment/docker/Dockerfile index b868f1c9fd..3db87f093a 100644 --- a/deployment/docker/Dockerfile +++ b/deployment/docker/Dockerfile @@ -18,7 +18,7 @@ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 +FROM nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04 LABEL maintainer='Microsoft NNI Team' diff --git a/src/nni_manager/common/experimentStartupInfo.ts b/src/nni_manager/common/experimentStartupInfo.ts index 482be0c5e0..7e8c9e0307 100644 --- a/src/nni_manager/common/experimentStartupInfo.ts +++ b/src/nni_manager/common/experimentStartupInfo.ts @@ -26,15 +26,17 @@ import * as component from '../common/component'; class ExperimentStartupInfo { private experimentId: string = ''; private newExperiment: boolean = true; + private basePort: number = -1; private initialized: boolean = false; private initTrialSequenceID: number = 0; - public setStartupInfo(newExperiment: boolean, experimentId: string): void { + public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void { assert(!this.initialized); assert(experimentId.trim().length > 0); this.newExperiment = newExperiment; this.experimentId = experimentId; + this.basePort = basePort; this.initialized = true; } @@ -44,6 +46,12 @@ class ExperimentStartupInfo { return this.experimentId; } + public getBasePort(): number { + assert(this.initialized); + + return this.basePort; + } + public isNewExperiment(): boolean { assert(this.initialized); @@ -66,6 +74,10 @@ function getExperimentId(): string { return component.get(ExperimentStartupInfo).getExperimentId(); } +function getBasePort(): number { + return component.get(ExperimentStartupInfo).getBasePort(); +} + function isNewExperiment(): boolean { return component.get(ExperimentStartupInfo).isNewExperiment(); } @@ -78,9 +90,9 @@ function getInitTrialSequenceId(): number { return component.get(ExperimentStartupInfo).getInitTrialSequenceId(); } -function setExperimentStartupInfo(newExperiment: boolean, experimentId: string): void { - component.get(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId); +function setExperimentStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void { + component.get(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId, basePort); } -export { ExperimentStartupInfo, getExperimentId, isNewExperiment, +export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId }; diff --git a/src/nni_manager/common/restServer.ts b/src/nni_manager/common/restServer.ts index 7929e4344a..66320eefd8 100644 --- a/src/nni_manager/common/restServer.ts +++ b/src/nni_manager/common/restServer.ts @@ -19,10 +19,12 @@ 'use strict'; +import * as assert from 'assert'; import * as express from 'express'; import * as http from 'http'; import { Deferred } from 'ts-deferred'; import { getLogger, Logger } from './log'; +import { getBasePort } from './experimentStartupInfo'; /** * Abstraction class to create a RestServer @@ -39,13 +41,20 @@ export abstract class RestServer { protected port?: number; protected app: express.Application = express(); protected log: Logger = getLogger(); + protected basePort?: number; + constructor() { + this.port = getBasePort(); + assert(this.port && this.port > 1024); + } + get endPoint(): string { // tslint:disable-next-line:no-http-string return `http://${this.hostName}:${this.port}`; } - public start(port?: number, hostName?: string): Promise { + public start(hostName?: string): Promise { + this.log.info(`RestServer start`); if (this.startTask !== undefined) { return this.startTask.promise; } @@ -56,9 +65,8 @@ export abstract class RestServer { if (hostName) { this.hostName = hostName; } - if (port) { - this.port = port; - } + + this.log.info(`RestServer base port is ${this.port}`); this.server = this.app.listen(this.port as number, this.hostName).on('listening', () => { this.startTask.resolve(); diff --git a/src/nni_manager/common/utils.ts b/src/nni_manager/common/utils.ts index 36f0bf8c62..850b88a652 100644 --- a/src/nni_manager/common/utils.ts +++ b/src/nni_manager/common/utils.ts @@ -222,7 +222,7 @@ function prepareUnitTest(): void { Container.snapshot(TrainingService); Container.snapshot(Manager); - setExperimentStartupInfo(true, 'unittest'); + setExperimentStartupInfo(true, 'unittest', 8080); mkDirPSync(getLogDir()); const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite'); diff --git a/src/nni_manager/main.ts b/src/nni_manager/main.ts index b5e3e07c66..0476749855 100644 --- a/src/nni_manager/main.ts +++ b/src/nni_manager/main.ts @@ -39,10 +39,10 @@ import { import { PAITrainingService } from './training_service/pai/paiTrainingService' -function initStartupInfo(startExpMode: string, resumeExperimentId: string) { +function initStartupInfo(startExpMode: string, resumeExperimentId: string, basePort: number) { const createNew: boolean = (startExpMode === 'new'); const expId: string = createNew ? uniqueString(8) : resumeExperimentId; - setExperimentStartupInfo(createNew, expId); + setExperimentStartupInfo(createNew, expId, basePort); } async function initContainer(platformMode: string): Promise { @@ -93,14 +93,14 @@ if (startMode === 'resume' && experimentId.trim().length < 1) { process.exit(1); } -initStartupInfo(startMode, experimentId); +initStartupInfo(startMode, experimentId, port); mkDirP(getLogDir()).then(async () => { const log: Logger = getLogger(); try { await initContainer(mode); const restServer: NNIRestServer = component.get(NNIRestServer); - await restServer.start(port); + await restServer.start(); log.info(`Rest server listening on: ${restServer.endPoint}`); } catch (err) { log.error(`${err.stack}`); diff --git a/src/nni_manager/training_service/pai/paiData.ts b/src/nni_manager/training_service/pai/paiData.ts index 461af034e6..a9e3e898b6 100644 --- a/src/nni_manager/training_service/pai/paiData.ts +++ b/src/nni_manager/training_service/pai/paiData.ts @@ -62,8 +62,8 @@ fi`; export const PAI_TRIAL_COMMAND_FORMAT: string = `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} && cd $NNI_SYS_DIR && sh install_nni.sh -&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --pai_hdfs_output_dir '{6}' ---pai_hdfs_host '{7}' --pai_user_name {8}`; +&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --nnimanager_port '{6}' +--pai_hdfs_output_dir '{7}' --pai_hdfs_host '{8}' --pai_user_name {9}`; export const PAI_OUTPUT_DIR_FORMAT: string = `hdfs://{0}:9000/`; diff --git a/src/nni_manager/training_service/pai/paiJobRestServer.ts b/src/nni_manager/training_service/pai/paiJobRestServer.ts index 098ea74333..4bba44da85 100644 --- a/src/nni_manager/training_service/pai/paiJobRestServer.ts +++ b/src/nni_manager/training_service/pai/paiJobRestServer.ts @@ -19,9 +19,11 @@ 'use strict'; +import * as assert from 'assert'; import { Request, Response, Router } from 'express'; import * as bodyParser from 'body-parser'; import * as component from '../../common/component'; +import { getBasePort } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo'; import { Inject } from 'typescript-ioc'; import { PAITrainingService } from './paiTrainingService'; @@ -48,10 +50,20 @@ export class PAIJobRestServer extends RestServer{ */ constructor() { super(); - this.port = PAIJobRestServer.DEFAULT_PORT; + const basePort: number = getBasePort(); + assert(basePort && basePort > 1024); + + this.port = basePort + 1; // PAIJobRestServer.DEFAULT_PORT; this.paiTrainingService = component.get(PAITrainingService); } + public get paiRestServerPort(): number { + if(!this.port) { + throw new Error('PAI Rest server port is undefined'); + } + return this.port; + } + /** * NNIRestServer's own router registration */ diff --git a/src/nni_manager/training_service/pai/paiTrainingService.ts b/src/nni_manager/training_service/pai/paiTrainingService.ts index 55b1215271..013b568af8 100644 --- a/src/nni_manager/training_service/pai/paiTrainingService.ts +++ b/src/nni_manager/training_service/pai/paiTrainingService.ts @@ -68,6 +68,7 @@ class PAITrainingService implements TrainingService { private hdfsBaseDir: string | undefined; private hdfsOutputHost: string | undefined; private trialSequenceId: number; + private paiRestServerPort?: number; constructor() { this.log = getLogger(); @@ -145,6 +146,11 @@ class PAITrainingService implements TrainingService { throw new Error('hdfsOutputHost is not initialized'); } + if(!this.paiRestServerPort) { + const restServer: PAIJobRestServer = component.get(PAIJobRestServer); + this.paiRestServerPort = restServer.paiRestServerPort; + } + this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`); const trialJobId: string = uniqueString(5); @@ -200,6 +206,7 @@ class PAITrainingService implements TrainingService { this.experimentId, this.paiTrialConfig.command, getIPV4Address(), + this.paiRestServerPort, hdfsOutputDir, this.hdfsOutputHost, this.paiClusterConfig.userName diff --git a/tools/nni_trial_tool/constants.py b/tools/nni_trial_tool/constants.py index 1a199c0ccc..c767e03a89 100644 --- a/tools/nni_trial_tool/constants.py +++ b/tools/nni_trial_tool/constants.py @@ -24,8 +24,6 @@ BASE_URL = 'http://{}' -DEFAULT_REST_PORT = 51189 - HOME_DIR = os.path.join(os.environ['HOME'], 'nni') LOG_DIR = os.environ['NNI_OUTPUT_DIR'] diff --git a/tools/nni_trial_tool/metrics_reader.py b/tools/nni_trial_tool/metrics_reader.py index 79af81f118..5d688a1858 100644 --- a/tools/nni_trial_tool/metrics_reader.py +++ b/tools/nni_trial_tool/metrics_reader.py @@ -24,7 +24,7 @@ import re import requests -from .constants import BASE_URL, DEFAULT_REST_PORT +from .constants import BASE_URL from .rest_utils import rest_get, rest_post, rest_put, rest_delete from .url_utils import gen_update_metrics_url @@ -40,11 +40,10 @@ class TrialMetricsReader(): ''' Read metrics data from a trial job ''' - def __init__(self, rest_port = DEFAULT_REST_PORT): + def __init__(self): metrics_base_dir = os.path.join(NNI_SYS_DIR, '.nni') self.offset_filename = os.path.join(metrics_base_dir, 'metrics_offset') self.metrics_filename = os.path.join(metrics_base_dir, 'metrics') - self.rest_port = rest_port if not os.path.exists(metrics_base_dir): os.makedirs(metrics_base_dir) @@ -107,7 +106,7 @@ def read_trial_metrics(self): offset = self._get_offset() return self._read_all_available_records(offset) -def read_experiment_metrics(nnimanager_ip): +def read_experiment_metrics(nnimanager_ip, nnimanager_port): ''' Read metrics data for specified trial jobs ''' @@ -118,7 +117,7 @@ def read_experiment_metrics(nnimanager_ip): result['metrics'] = reader.read_trial_metrics() print('Result metrics is {}'.format(json.dumps(result))) if len(result['metrics']) > 0: - response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), DEFAULT_REST_PORT, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10) + response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), nnimanager_port, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10) print('Response code is {}'.format(response.status_code)) except Exception: #TODO error logging to file diff --git a/tools/nni_trial_tool/trial_keeper.py b/tools/nni_trial_tool/trial_keeper.py index ab1b42ac64..ae8fa8a8d5 100644 --- a/tools/nni_trial_tool/trial_keeper.py +++ b/tools/nni_trial_tool/trial_keeper.py @@ -48,7 +48,7 @@ def main_loop(args): while True: retCode = process.poll() ## Read experiment metrics, to avoid missing metrics - read_experiment_metrics(args.nnimanager_ip) + read_experiment_metrics(args.nnimanager_ip, args.nnimanager_port) if retCode is not None: print('subprocess terminated. Exit code is {}. Quit'.format(retCode)) @@ -80,7 +80,8 @@ def trial_keeper_help_info(*args): PARSER = argparse.ArgumentParser() PARSER.set_defaults(func=trial_keeper_help_info) PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process') - PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager IP') + PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP') + PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port') PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs') PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')