Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix remote machine connection logic #2725

Merged
merged 18 commits into from
Aug 7, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class RemoteMachineTrainingService implements TrainingService {
private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true;
private logCollection: string;
private sshConnectionPromises: any[];

constructor(@component.Inject timer: ObservableTimer) {
this.metricsEmitter = new EventEmitter();
Expand All @@ -66,6 +67,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineMeta, Promise<void>>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = [];
this.sshConnectionPromises = [];
this.expRootDir = getExperimentRootDir();
this.timer = timer;
this.log = getLogger();
Expand All @@ -81,6 +83,12 @@ class RemoteMachineTrainingService implements TrainingService {
await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info('Run remote machine training service.');
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
}
while (!this.stopping) {
while (this.jobQueue.length > 0) {
this.updateGpuReservation();
Expand Down Expand Up @@ -409,7 +417,6 @@ class RemoteMachineTrainingService implements TrainingService {
//TO DO: verify if value's format is wrong, and json parse failed, how to handle error
const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);

const connectionPromises = [];
for (const rmMeta of rmMetaList) {
rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
Expand All @@ -418,11 +425,9 @@ class RemoteMachineTrainingService implements TrainingService {
this.log.debug(`reached ${executor.name}`);
this.machineExecutorManagerMap.set(rmMeta, executorManager);
this.log.debug(`initializing ${executor.name}`);
connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connected to ${executor.name}`);
this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connecting to ${executor.name}`);
}

await Promise.all(connectionPromises);
}

private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
Expand Down