diff --git a/src/visualizers/panels/TensorPlotter/TensorPlotterControl.js b/src/visualizers/panels/TensorPlotter/TensorPlotterControl.js index a0a4f67f9..447f23e04 100644 --- a/src/visualizers/panels/TensorPlotter/TensorPlotterControl.js +++ b/src/visualizers/panels/TensorPlotter/TensorPlotterControl.js @@ -1,18 +1,99 @@ /*globals define */ -/** - * Generated by VisualizerGenerator 1.7.0 from webgme on Mon May 04 2020 17:09:31 GMT-0500 (Central Daylight Time). - */ define([ 'panels/InteractiveExplorer/InteractiveExplorerControl', + 'text!./explorer_helpers.py', ], function ( InteractiveExplorerControl, + HELPERS_PY, ) { 'use strict'; class TensorPlotterControl extends InteractiveExplorerControl { + initializeWidgetHandlers (widget) { + super.initializeWidgetHandlers(widget); + widget.getPoints = lineInfo => this.getPoints(lineInfo); + widget.getColorValues = lineInfo => this.getColorValues(lineInfo); + widget.getMetadata = desc => this.getMetadata(desc); + } + + async onComputeInitialized (session) { + super.onComputeInitialized(session); + this._widget.artifactLoader.session = session; + const initCode = await this.getInitializationCode(); + await session.addFile('utils/init.py', initCode); + await session.addFile('utils/explorer_helpers.py', HELPERS_PY); + } + + async getPoints (lineInfo) { + const {data, dataSlice=''} = lineInfo; + const {pyImport, varName} = this.getImportCode(data); + const command = [ + 'import utils.init', + pyImport, + 'from utils.explorer_helpers import print_points', + `print_points(${varName}${dataSlice})` + ].join('\n'); + const stdout = await this.execPy(command); + return JSON.parse(stdout); + } + + async getColorValues (lineInfo) { + const {colorData, colorDataSlice='', startColor, endColor} = lineInfo; + const {pyImport, varName} = this.getImportCode(colorData); + const command = [ + 'import utils.init', + pyImport, + 'from utils.explorer_helpers import print_colors', + `data = ${varName}${colorDataSlice}`, + `print_colors(data, "${startColor}", "${endColor}")` + ].join('\n'); + const stdout = await this.execPy(command); + return JSON.parse(stdout); + } + + async getMetadata (desc) { + const {name} = desc; + const {pyImport, varName} = this.getImportCode(name); + const command = [ + 'import utils.init', + pyImport, + 'from utils.explorer_helpers import print_metadata', + `print_metadata("${varName}", ${varName})`, + ].join('\n'); + const stdout = await this.execPy(command); + return JSON.parse(stdout); + } + + getImportCode (artifactName) { + const pyName = artifactName.replace(/\..*$/, ''); + const [modName, ...accessors] = pyName.split('['); + const pyImport = `from artifacts.${modName} import data as ${modName}`; + const accessor = accessors.length ? '[' + accessors.join('[') : ''; + const varName = modName + accessor; + return { + pyImport, varName + }; + } + + async execPy(code) { + try { + const i = ++this.cmdCount; + await this.session.addFile(`cmd_${i}.py`, code); + const {stdout} = await this.session.exec(`python cmd_${i}.py`); + await this.session.removeFile(`cmd_${i}.py`); + return stdout; + } catch (err) { + const {stderr} = err.jobResult; + const wrappedError = new Error(err.message); + wrappedError.stderr = stderr; + wrappedError.code = code; + throw wrappedError; + } + } + getObjectDescriptor(nodeId) { const desc = super.getObjectDescriptor(nodeId); diff --git a/src/visualizers/widgets/TensorPlotter/files/explorer_helpers.py b/src/visualizers/panels/TensorPlotter/explorer_helpers.py similarity index 100% rename from src/visualizers/widgets/TensorPlotter/files/explorer_helpers.py rename to src/visualizers/panels/TensorPlotter/explorer_helpers.py diff --git a/src/visualizers/widgets/TrainKeras/Main.py b/src/visualizers/panels/TrainKeras/Main.py similarity index 100% rename from src/visualizers/widgets/TrainKeras/Main.py rename to src/visualizers/panels/TrainKeras/Main.py diff --git a/src/visualizers/panels/TrainKeras/TrainKerasControl.js b/src/visualizers/panels/TrainKeras/TrainKerasControl.js index e51d757f9..3fa234763 100644 --- a/src/visualizers/panels/TrainKeras/TrainKerasControl.js +++ b/src/visualizers/panels/TrainKeras/TrainKerasControl.js @@ -3,33 +3,58 @@ define([ 'panels/InteractiveExplorer/InteractiveExplorerControl', 'deepforge/globals', + 'deepforge/PromiseEvents', + 'deepforge/compute/interactive/message', 'deepforge/CodeGenerator', + 'plugin/GenerateJob/GenerateJob/templates/index', + 'text!./Main.py', + 'text!./TrainOperation.py', 'deepforge/OperationCode', './JSONImporter', + 'deepforge/Constants', 'js/Constants', 'q', 'underscore', ], function ( InteractiveExplorerControl, DeepForge, + PromiseEvents, + Message, CodeGenerator, + JobTemplates, + MainCode, + TrainOperation, OperationCode, Importer, CONSTANTS, + GME_CONSTANTS, Q, _, ) { 'use strict'; + MainCode = _.template(MainCode); + const GetTrainCode = _.template(TrainOperation); class TrainKerasControl extends InteractiveExplorerControl { + constructor() { + super(...arguments); + this.modelCount = 0; + } + initializeWidgetHandlers (widget) { super.initializeWidgetHandlers(widget); const self = this; widget.getArchitectureCode = id => this.getArchitectureCode(id); widget.saveModel = function() {return self.saveModel(...arguments);}; widget.getNodeSnapshot = id => this.getNodeSnapshot(id); + widget.stopCurrentTask = () => this.stopTask(this.currentTrainTask); + widget.train = config => this.train(config); + widget.isTrainingModel = () => this.isTrainingModel(); + widget.getCurrentModelID = () => this.getCurrentModelID(); + widget.createModelInfo = config => this.createModelInfo(config); + widget.addArtifact = (dataset, auth) => this.addArtifact(dataset, auth); } async getNodeSnapshot(id) { @@ -41,14 +66,112 @@ define([ return state; } - async saveModel(modelInfo, storage, session) { - const metadata = (await session.forkAndRun( + async onComputeInitialized(session) { + super.onComputeInitialized(session); + const initCode = await this.getInitializationCode(); + await session.addFile('utils/init.py', initCode); + await session.addFile('plotly_backend.py', JobTemplates.MATPLOTLIB_BACKEND); + await session.setEnvVar('MPLBACKEND', 'module://plotly_backend'); + } + + async stopTask(task) { + await this.session.kill(task); + } + + async addArtifact(dataset, auth) { + await this.session.addArtifact(dataset.name, dataset.dataInfo, dataset.type, auth); + } + + async createModelInfo(config) { + this.modelCount++; + const saveName = this.getCurrentModelID(); + const architecture = await this.getNodeSnapshot(config.architecture.id); + return { + id: saveName, + path: saveName, + name: saveName, + config, + architecture + }; + } + + getCurrentModelID() { + return `model_${this.modelCount}`; + } + + train(modelInfo) { + const self = this; + return PromiseEvents.new(async function(resolve) { + this.emit('update', 'Generating Code'); + await self.initTrainingCode(modelInfo); + this.emit('update', 'Training...'); + const trainTask = self.session.spawn('python start_train.py'); + self.currentTrainTask = trainTask; + self.currentTrainTask.on(Message.STDOUT, data => { + let line = data.toString(); + if (line.startsWith(CONSTANTS.START_CMD)) { + line = line.substring(CONSTANTS.START_CMD.length + 1); + const splitIndex = line.indexOf(' '); + const cmd = line.substring(0, splitIndex); + const content = JSON.parse(line.substring(splitIndex + 1)); + if (cmd === 'PLOT') { + this.emit('plot', content); + } else { + console.error('Unrecognized command:', cmd); + } + } + }); + let stderr = ''; + self.currentTrainTask.on(Message.STDERR, data => stderr += data.toString()); + self.currentTrainTask.on(Message.COMPLETE, exitCode => { + if (exitCode) { + this.emit('error', stderr); + } else { + this.emit('end'); + } + if (self.currentTrainTask === trainTask) { + self.currentTrainTask = null; + } + resolve(); + }); + }); + } + + async initTrainingCode(modelInfo) { + const {config} = modelInfo; + const {dataset, architecture, path, loss, optimizer} = config; + const archCode = await this.getArchitectureCode(architecture.id); + loss.arguments.concat(optimizer.arguments).forEach(arg => { + let pyValue = arg.value.toString(); + if (arg.type === 'boolean') { + pyValue = arg.value ? 'True' : 'False'; + } else if (arg.type === 'enum') { + pyValue = `"${arg.value}"`; + } + arg.pyValue = pyValue; + }); + await this.session.addFile('start_train.py', MainCode({ + dataset, + path, + archCode + })); + const trainPy = GetTrainCode(config); + await this.session.addFile('operations/train.py', trainPy); + } + + isTrainingModel() { + return !!this.currentTrainTask; + } + + async saveModel(modelInfo, storage) { + modelInfo.code = GetTrainCode(modelInfo.config); + const metadata = (await this.session.forkAndRun( session => session.exec(`cat outputs/${modelInfo.path}/metadata.json`) )).stdout; const {type} = JSON.parse(metadata); const projectId = this.client.getProjectInfo()._id; const savePath = `${projectId}/artifacts/${modelInfo.name}`; - const dataInfo = await session.forkAndRun( + const dataInfo = await this.session.forkAndRun( session => session.saveArtifact( `outputs/${modelInfo.path}/data`, savePath, @@ -239,13 +362,13 @@ define([ .forEach(event => { switch (event.etype) { - case CONSTANTS.TERRITORY_EVENT_LOAD: + case GME_CONSTANTS.TERRITORY_EVENT_LOAD: this.onResourceLoad(event.eid); break; - case CONSTANTS.TERRITORY_EVENT_UPDATE: + case GME_CONSTANTS.TERRITORY_EVENT_UPDATE: this.onResourceUpdate(event.eid); break; - case CONSTANTS.TERRITORY_EVENT_UNLOAD: + case GME_CONSTANTS.TERRITORY_EVENT_UNLOAD: this.onResourceUnload(event.eid); break; default: diff --git a/src/visualizers/panels/TrainKeras/TrainKerasPanel.js b/src/visualizers/panels/TrainKeras/TrainKerasPanel.js index 0a3a939b5..91787db0b 100644 --- a/src/visualizers/panels/TrainKeras/TrainKerasPanel.js +++ b/src/visualizers/panels/TrainKeras/TrainKerasPanel.js @@ -1,101 +1,70 @@ -/*globals define, _, WebGMEGlobal*/ +/*globals define, WebGMEGlobal*/ /** * Generated by VisualizerGenerator 1.7.0 from webgme on Mon Jul 27 2020 14:55:57 GMT-0500 (Central Daylight Time). */ define([ 'js/PanelBase/PanelBaseWithHeader', - 'js/PanelManager/IActivePanel', + 'panels/InteractiveEditor/InteractiveEditorPanel', 'widgets/TrainKeras/TrainKerasWidget', - './TrainKerasControl' + './TrainKerasControl', ], function ( PanelBaseWithHeader, - IActivePanel, + InteractiveEditorPanel, TrainKerasWidget, - TrainKerasControl + TrainKerasControl, ) { 'use strict'; - function TrainKerasPanel(layoutManager, params) { - var options = {}; - //set properties from options - options[PanelBaseWithHeader.OPTIONS.LOGGER_INSTANCE_NAME] = 'TrainKerasPanel'; - options[PanelBaseWithHeader.OPTIONS.FLOATING_TITLE] = true; - - //call parent's constructor - PanelBaseWithHeader.apply(this, [options, layoutManager]); - - this._client = params.client; - this._embedded = params.embedded; - - //initialize UI - this._initialize(); - - this.logger.debug('ctor finished'); + class TrainKerasPanel extends InteractiveEditorPanel { + constructor(layoutManager, params) { + const config = { + name: 'TrainKeras', + Control: TrainKerasControl, + Widget: TrainKerasWidget, + }; + super(config, params); + + this.logger.debug('ctor finished'); + } + + /* OVERRIDE FROM WIDGET-WITH-HEADER */ + /* METHOD CALLED WHEN THE WIDGET'S READ-ONLY PROPERTY CHANGES */ + onReadOnlyChanged(isReadOnly) { + //apply parent's onReadOnlyChanged + PanelBaseWithHeader.prototype.onReadOnlyChanged.call(this, isReadOnly); + + } + + onResize(width, height) { + this.logger.debug('onResize --> width: ' + width + ', height: ' + height); + this.widget.onWidgetContainerResize(width, height); + } + + /* * * * * * * * Visualizer life cycle callbacks * * * * * * * */ + destroy() { + this.control.destroy(); + this.widget.destroy(); + + PanelBaseWithHeader.prototype.destroy.call(this); + WebGMEGlobal.KeyboardManager.setListener(undefined); + WebGMEGlobal.Toolbar.refresh(); + } + + onActivate() { + this.widget.onActivate(); + this.control.onActivate(); + WebGMEGlobal.KeyboardManager.setListener(this.widget); + WebGMEGlobal.Toolbar.refresh(); + } + + onDeactivate() { + this.widget.onDeactivate(); + this.control.onDeactivate(); + WebGMEGlobal.KeyboardManager.setListener(undefined); + WebGMEGlobal.Toolbar.refresh(); + } } - //inherit from PanelBaseWithHeader - _.extend(TrainKerasPanel.prototype, PanelBaseWithHeader.prototype); - _.extend(TrainKerasPanel.prototype, IActivePanel.prototype); - - TrainKerasPanel.prototype._initialize = function () { - var self = this; - - //set Widget title - this.setTitle(''); - - this.widget = new TrainKerasWidget(this.logger, this.$el); - - this.widget.setTitle = function (title) { - self.setTitle(title); - }; - - this.control = new TrainKerasControl({ - logger: this.logger, - client: this._client, - embedded: this._embedded, - widget: this.widget - }); - - this.onActivate(); - }; - - /* OVERRIDE FROM WIDGET-WITH-HEADER */ - /* METHOD CALLED WHEN THE WIDGET'S READ-ONLY PROPERTY CHANGES */ - TrainKerasPanel.prototype.onReadOnlyChanged = function (isReadOnly) { - //apply parent's onReadOnlyChanged - PanelBaseWithHeader.prototype.onReadOnlyChanged.call(this, isReadOnly); - - }; - - TrainKerasPanel.prototype.onResize = function (width, height) { - this.logger.debug('onResize --> width: ' + width + ', height: ' + height); - this.widget.onWidgetContainerResize(width, height); - }; - - /* * * * * * * * Visualizer life cycle callbacks * * * * * * * */ - TrainKerasPanel.prototype.destroy = function () { - this.control.destroy(); - this.widget.destroy(); - - PanelBaseWithHeader.prototype.destroy.call(this); - WebGMEGlobal.KeyboardManager.setListener(undefined); - WebGMEGlobal.Toolbar.refresh(); - }; - - TrainKerasPanel.prototype.onActivate = function () { - this.widget.onActivate(); - this.control.onActivate(); - WebGMEGlobal.KeyboardManager.setListener(this.widget); - WebGMEGlobal.Toolbar.refresh(); - }; - - TrainKerasPanel.prototype.onDeactivate = function () { - this.widget.onDeactivate(); - this.control.onDeactivate(); - WebGMEGlobal.KeyboardManager.setListener(undefined); - WebGMEGlobal.Toolbar.refresh(); - }; - return TrainKerasPanel; }); diff --git a/src/visualizers/widgets/TrainKeras/TrainOperation.py b/src/visualizers/panels/TrainKeras/TrainOperation.py similarity index 100% rename from src/visualizers/widgets/TrainKeras/TrainOperation.py rename to src/visualizers/panels/TrainKeras/TrainOperation.py diff --git a/src/visualizers/widgets/TensorPlotter/ArtifactLoader.js b/src/visualizers/widgets/TensorPlotter/ArtifactLoader.js index ae49f2c81..51485bff6 100644 --- a/src/visualizers/widgets/TensorPlotter/ArtifactLoader.js +++ b/src/visualizers/widgets/TensorPlotter/ArtifactLoader.js @@ -14,7 +14,6 @@ define([ class ArtifactLoader extends EventEmitter { constructor(container) { super(); - this.session = null; this.$el = container; this.$el.addClass('artifact-loader'); this.$el.append($(Html)); diff --git a/src/visualizers/widgets/TensorPlotter/TensorPlotterWidget.js b/src/visualizers/widgets/TensorPlotter/TensorPlotterWidget.js index ab5241290..8e5e9d2f9 100644 --- a/src/visualizers/widgets/TensorPlotter/TensorPlotterWidget.js +++ b/src/visualizers/widgets/TensorPlotter/TensorPlotterWidget.js @@ -3,23 +3,19 @@ define([ 'widgets/InteractiveExplorer/InteractiveExplorerWidget', 'deepforge/storage/index', - 'deepforge/compute/interactive/session-with-queue', 'webgme-plotly/plotly.min', './PlotEditor', './ArtifactLoader', 'underscore', - 'text!./files/explorer_helpers.py', 'deepforge/viz/InformDialog', 'css!./styles/TensorPlotterWidget.css', ], function ( InteractiveExplorerWidget, Storage, - Session, Plotly, PlotEditor, ArtifactLoader, _, - HELPERS_PY, InformDialog, ) { 'use strict'; @@ -33,7 +29,6 @@ define([ this.cmdCount = 0; this.currentPlotData = null; - this.session = null; this.$el = container; this.$el.addClass(WIDGET_CLASS); const row = $('
${code}
${stderr}
`;
- const dialog = new InformDialog('Plotting failed.', msg);
- dialog.show();
- throw err;
- }
- }
-
- async getPoints (lineInfo) {
- const {data, dataSlice=''} = lineInfo;
- const {pyImport, varName} = this.getImportCode(data);
- const command = [
- 'import utils.init',
- pyImport,
- 'from utils.explorer_helpers import print_points',
- `print_points(${varName}${dataSlice})`
- ].join('\n');
- const stdout = await this.execPy(command);
- return JSON.parse(stdout);
- }
-
- async getColorValues (lineInfo) {
- const {colorData, colorDataSlice='', startColor, endColor} = lineInfo;
- const {pyImport, varName} = this.getImportCode(colorData);
- const command = [
- 'import utils.init',
- pyImport,
- 'from utils.explorer_helpers import print_colors',
- `data = ${varName}${colorDataSlice}`,
- `print_colors(data, "${startColor}", "${endColor}")`
- ].join('\n');
- const stdout = await this.execPy(command);
- return JSON.parse(stdout);
- }
-
- async getMetadata (desc) {
- const {name} = desc;
- const {pyImport, varName} = this.getImportCode(name);
- const command = [
- 'import utils.init',
- pyImport,
- 'from utils.explorer_helpers import print_metadata',
- `print_metadata("${varName}", ${varName})`,
- ].join('\n');
- const stdout = await this.execPy(command);
- return JSON.parse(stdout);
- }
-
- getImportCode (artifactName) {
- const pyName = artifactName.replace(/\..*$/, '');
- const [modName, ...accessors] = pyName.split('[');
- const pyImport = `from artifacts.${modName} import data as ${modName}`;
- const accessor = accessors.length ? '[' + accessors.join('[') : '';
- const varName = modName + accessor;
- return {
- pyImport, varName
- };
- }
-
async getPlotData (line) {
const {shape} = line;
const dim = shape[1];
@@ -228,18 +148,27 @@ define([
}
async getPlotlyJSON (figureData) {
- const layout = _.pick(figureData, ['title', 'xaxis', 'yaxis']);
- if (layout.xaxis) {
- layout.xaxis = {title: layout.xaxis};
- }
- if (layout.yaxis) {
- layout.yaxis = {title: layout.yaxis};
- }
- const data = [];
- for (let i = 0; i < figureData.data.length; i++) {
- data.push(await this.getPlotData(figureData.data[i]));
+ try {
+ const layout = _.pick(figureData, ['title', 'xaxis', 'yaxis']);
+ if (layout.xaxis) {
+ layout.xaxis = {title: layout.xaxis};
+ }
+ if (layout.yaxis) {
+ layout.yaxis = {title: layout.yaxis};
+ }
+ const data = [];
+ for (let i = 0; i < figureData.data.length; i++) {
+ data.push(await this.getPlotData(figureData.data[i]));
+ }
+ return {data, layout};
+ } catch (err) {
+ const {stderr, code} = err;
+ const msg = `Command:${code}
${stderr}
`;
+ const dialog = new InformDialog('Plotting failed.', msg);
+ dialog.show();
+ throw err;
}
- return {data, layout};
}
async updatePlot (figureData) {
@@ -261,9 +190,6 @@ define([
/* * * * * * * * Visualizer life cycle callbacks * * * * * * * */
destroy () {
Plotly.purge(this.$plot[0]);
- if (this.session) {
- this.session.close();
- }
}
onActivate () {
diff --git a/src/visualizers/widgets/TrainKeras/TrainKerasWidget.js b/src/visualizers/widgets/TrainKeras/TrainKerasWidget.js
index c9822b628..3aeebdfca 100644
--- a/src/visualizers/widgets/TrainKeras/TrainKerasWidget.js
+++ b/src/visualizers/widgets/TrainKeras/TrainKerasWidget.js
@@ -2,16 +2,10 @@
define([
'./build/TrainDashboard',
- 'plugin/GenerateJob/GenerateJob/templates/index',
- 'deepforge/Constants',
'deepforge/storage/index',
'widgets/InteractiveEditor/InteractiveEditorWidget',
'deepforge/viz/ConfigDialog',
- 'deepforge/compute/interactive/message',
- 'deepforge/compute/line-collector',
'webgme-plotly/plotly.min',
- 'text!./TrainOperation.py',
- 'text!./Main.py',
'deepforge/viz/StorageHelpers',
'deepforge/viz/ConfirmDialog',
'deepforge/viz/InformDialog',
@@ -21,16 +15,10 @@ define([
'css!./styles/TrainKerasWidget.css',
], function (
TrainDashboard,
- JobTemplates,
- CONSTANTS,
Storage,
InteractiveEditor,
ConfigDialog,
- Message,
- LineCollector,
Plotly,
- TrainOperation,
- MainCode,
StorageHelpers,
ConfirmDialog,
InformDialog,
@@ -40,9 +28,7 @@ define([
'use strict';
const WIDGET_CLASS = 'train-keras';
- const GetTrainCode = _.template(TrainOperation);
const DashboardSchemas = JSON.parse(SchemaText);
- MainCode = _.template(MainCode);
class TrainKerasWidget extends InteractiveEditor {
constructor(logger, container) {
@@ -51,7 +37,7 @@ define([
this.dashboard.initialize(Plotly, DashboardSchemas);
this.dashboard.events().addEventListener(
'onTrainClicked',
- () => this.train(this.dashboard.data())
+ () => this.onTrainClicked()
);
this.dashboard.events().addEventListener(
'saveModel',
@@ -61,25 +47,16 @@ define([
'showModelInfo',
event => this.onShowModelInfo(event.detail)
);
- this.modelCount = 0;
container.addClass(WIDGET_CLASS);
- this.currentTrainTask = null;
this.loadedData = [];
}
- async onComputeInitialized(session) {
- const initCode = await this.getInitializationCode();
- await session.addFile('utils/init.py', initCode);
- await session.addFile('plotly_backend.py', JobTemplates.MATPLOTLIB_BACKEND);
- await session.setEnvVar('MPLBACKEND', 'module://plotly_backend');
- }
-
isDataLoaded(dataset) {
return this.loadedData.find(data => _.isEqual(data, dataset));
}
- async train(config) {
- if (this.currentTrainTask) {
+ async onTrainClicked() {
+ if (this.isTrainingModel()) {
const title = 'Stop Current Training';
const body = 'Would you like to stop the current training to train a model with the new configuration?';
const dialog = new ConfirmDialog(title, body);
@@ -90,75 +67,43 @@ define([
}
this.dashboard.setModelState(this.getCurrentModelID(), 'Canceled');
- await this.session.kill(this.currentTrainTask);
+ await this.stopCurrentTask();
}
- this.modelCount++;
- const saveName = this.getCurrentModelID();
- const architecture = await this.getNodeSnapshot(config.architecture.id);
- const modelInfo = {
- id: saveName,
- path: saveName,
- name: saveName,
- state: 'Fetching Data...',
- config,
- architecture
- };
- this.dashboard.addModel(modelInfo);
+ const config = this.dashboard.data();
const {dataset} = config;
+ const modelInfo = await this.createModelInfo(config);
+ modelInfo.state = 'Fetching Data';
+ this.dashboard.addModel(modelInfo);
+
if (!this.isDataLoaded(dataset)) {
this.loadedData.push(dataset);
const auth = await StorageHelpers.getAuthenticationConfig(dataset.dataInfo);
- await this.session.addArtifact(dataset.name, dataset.dataInfo, dataset.type, auth);
+ await this.addArtifact(dataset, auth);
}
- this.dashboard.setModelState(this.getCurrentModelID(), 'Generating Code');
- const archCode = await this.getArchitectureCode(config.architecture.id);
- config.loss.arguments.concat(config.optimizer.arguments).forEach(arg => {
- let pyValue = arg.value.toString();
- if (arg.type === 'boolean') {
- pyValue = arg.value ? 'True' : 'False';
- } else if (arg.type === 'enum') {
- pyValue = `"${arg.value}"`;
- }
- arg.pyValue = pyValue;
- });
- await this.session.addFile('start_train.py', MainCode({
- dataset,
- path: modelInfo.path,
- archCode
- }));
- const trainPy = GetTrainCode(config);
- await this.session.addFile('operations/train.py', trainPy);
- this.dashboard.setModelState(this.getCurrentModelID(), 'Training...');
- const trainTask = this.session.spawn('python start_train.py');
- this.currentTrainTask = trainTask;
- this.currentTrainTask.on(Message.STDOUT, data => {
- let line = data.toString();
- if (line.startsWith(CONSTANTS.START_CMD)) {
- line = line.substring(CONSTANTS.START_CMD.length + 1);
- const splitIndex = line.indexOf(' ');
- const cmd = line.substring(0, splitIndex);
- const content = line.substring(splitIndex + 1);
- this.parseMetadata(cmd, JSON.parse(content));
- }
- });
- let stderr = '';
- this.currentTrainTask.on(Message.STDERR, data => stderr += data.toString());
- this.currentTrainTask.on(Message.COMPLETE, exitCode => {
- if (exitCode) {
- this.dashboard.setModelState(modelInfo.id, 'Error Occurred', stderr);
- } else {
- this.dashboard.setModelState(modelInfo.id);
- }
- if (this.currentTrainTask === trainTask) {
- this.currentTrainTask = null;
- }
- });
+ const createTrainTask = this.train(modelInfo);
+ createTrainTask.on(
+ 'update',
+ status => this.dashboard.setModelState(modelInfo.id, status)
+ );
+ createTrainTask.on(
+ 'plot',
+ plotData => this.dashboard.setPlotData(modelInfo.id, plotData)
+ );
+ createTrainTask.on(
+ 'error',
+ stderr => this.dashboard.setModelState(modelInfo.id, 'Error Occurred', stderr)
+ );
+ createTrainTask.on(
+ 'end',
+ () => this.dashboard.setModelState(modelInfo.id)
+ );
+ await createTrainTask;
}
async onShowModelInfo(modelInfo) {
- let body = modelInfo.info.replace(/\n/g, '