From 846f404f28e89399f79dd344c6d3d89a5018b2bd Mon Sep 17 00:00:00 2001 From: Jonas B <97200640+SmiteDeluxe@users.noreply.github.com> Date: Tue, 28 May 2024 04:56:25 -0600 Subject: [PATCH] feat: infer column type using runner (#1182) Closes #994 ### Summary of Changes EDA adds another execution in getting the Table placeholder to get column types with the .getColumn().isNumeric call that based on that evals to numeric or categorical type. In future also consider the temporal type, but will have to overhaul the profiling and filters for that. --- .../src/extension/eda/apis/runnerApi.ts | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts b/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts index e55e55ce3..c09eafe0b 100644 --- a/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts +++ b/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts @@ -263,6 +263,12 @@ export class RunnerApi { private sdsStringForCorrelationHeatmap(tablePlaceholder: string, newPlaceholderName: string) { return 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.plot.correlationHeatmap(); \n'; } + + private sdsStringForIsNumeric(tablePlaceholder: string, columnName: string, newPlaceholderName: string) { + return ( + 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.getColumn("' + columnName + '").isNumeric; \n' + ); + } //#endregion //#region Placeholder handling @@ -283,6 +289,9 @@ export class RunnerApi { return; } this.services.runtime.PythonServer.removeMessageCallback('placeholder_value', placeholderValueCallback); + safeDsLogger.debug( + 'Got placeholder value: ' + JSON.stringify(message.data.value).slice(0, 100) + '...', + ); resolve(message.data.value); }; @@ -304,8 +313,28 @@ export class RunnerApi { //#region Table fetching public async getTableByPlaceholder(tableName: string, pipelineExecutionId: string): Promise { safeDsLogger.debug('Getting table by placeholder: ' + tableName); + const pythonTableColumns = await this.getPlaceholderValue(tableName, pipelineExecutionId); if (pythonTableColumns) { + // Get Column Types + safeDsLogger.debug('Getting column types for table: ' + tableName); + let sdsLines = ''; + let placeholderNames: string[] = []; + let columnNameToPlaceholderIsNumericNameMap = new Map(); + for (const columnName of Object.keys(pythonTableColumns)) { + const newPlaceholderName = this.genPlaceholderName(columnName + '_type'); + columnNameToPlaceholderIsNumericNameMap.set(columnName, newPlaceholderName); + placeholderNames.push(newPlaceholderName); + sdsLines += this.sdsStringForIsNumeric(tableName, columnName, newPlaceholderName); + } + + await this.addToAndExecutePipeline(pipelineExecutionId, sdsLines, placeholderNames); + const columnIsNumeric = new Map(); + for (const [columnName, placeholderName] of columnNameToPlaceholderIsNumericNameMap) { + const columnType = await this.getPlaceholderValue(placeholderName, pipelineExecutionId); + columnIsNumeric.set(columnName, columnType as string); + } + const table: Table = { totalRows: 0, name: tableName, @@ -322,8 +351,7 @@ export class RunnerApi { currentMax = columnValues.length; } - const isNumerical = typeof columnValues[0] === 'number'; - const columnType = isNumerical ? 'numerical' : 'categorical'; + const columnType = columnIsNumeric.get(columnName) ? 'numerical' : 'categorical'; const column: Column = { name: columnName,