Skip to content

Commit

Permalink
feat: infer column type using runner (#1182)
Browse files Browse the repository at this point in the history
Closes #994 

### Summary of Changes

EDA adds another execution in getting the Table placeholder to get
column types with the .getColumn().isNumeric call that based on that
evals to numeric or categorical type. In future also consider the
temporal type, but will have to overhaul the profiling and filters for
that.
  • Loading branch information
SmiteDeluxe authored May 28, 2024
1 parent d83c3d4 commit 846f404
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ export class RunnerApi {
private sdsStringForCorrelationHeatmap(tablePlaceholder: string, newPlaceholderName: string) {
return 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.plot.correlationHeatmap(); \n';
}

private sdsStringForIsNumeric(tablePlaceholder: string, columnName: string, newPlaceholderName: string) {
return (
'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.getColumn("' + columnName + '").isNumeric; \n'
);
}
//#endregion

//#region Placeholder handling
Expand All @@ -283,6 +289,9 @@ export class RunnerApi {
return;
}
this.services.runtime.PythonServer.removeMessageCallback('placeholder_value', placeholderValueCallback);
safeDsLogger.debug(
'Got placeholder value: ' + JSON.stringify(message.data.value).slice(0, 100) + '...',
);
resolve(message.data.value);
};

Expand All @@ -304,8 +313,28 @@ export class RunnerApi {
//#region Table fetching
public async getTableByPlaceholder(tableName: string, pipelineExecutionId: string): Promise<Table | undefined> {
safeDsLogger.debug('Getting table by placeholder: ' + tableName);

const pythonTableColumns = await this.getPlaceholderValue(tableName, pipelineExecutionId);
if (pythonTableColumns) {
// Get Column Types
safeDsLogger.debug('Getting column types for table: ' + tableName);
let sdsLines = '';
let placeholderNames: string[] = [];
let columnNameToPlaceholderIsNumericNameMap = new Map<string, string>();
for (const columnName of Object.keys(pythonTableColumns)) {
const newPlaceholderName = this.genPlaceholderName(columnName + '_type');
columnNameToPlaceholderIsNumericNameMap.set(columnName, newPlaceholderName);
placeholderNames.push(newPlaceholderName);
sdsLines += this.sdsStringForIsNumeric(tableName, columnName, newPlaceholderName);
}

await this.addToAndExecutePipeline(pipelineExecutionId, sdsLines, placeholderNames);
const columnIsNumeric = new Map<string, string>();
for (const [columnName, placeholderName] of columnNameToPlaceholderIsNumericNameMap) {
const columnType = await this.getPlaceholderValue(placeholderName, pipelineExecutionId);
columnIsNumeric.set(columnName, columnType as string);
}

const table: Table = {
totalRows: 0,
name: tableName,
Expand All @@ -322,8 +351,7 @@ export class RunnerApi {
currentMax = columnValues.length;
}

const isNumerical = typeof columnValues[0] === 'number';
const columnType = isNumerical ? 'numerical' : 'categorical';
const columnType = columnIsNumeric.get(columnName) ? 'numerical' : 'categorical';

const column: Column = {
name: columnName,
Expand Down

0 comments on commit 846f404

Please sign in to comment.