Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Data Frame Analytics: Fix feature importance cell value and decision path chart #82011

Merged
merged 7 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { EuiDataGridSorting } from '@elastic/eui';

import { multiColumnSortFactory } from './common';

describe('Transform: Define Pivot Common', () => {
describe('Data Frame Analytics: Data Grid Common', () => {
test('multiColumnSortFactory()', () => {
const data = [
{ s: 'a', n: 1 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import {
KBN_FIELD_TYPES,
} from '../../../../../../../src/plugins/data/public';

import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { extractErrorMessage } from '../../../../common/util/errors';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';

import {
BASIC_NUMERICAL_TYPES,
Expand Down Expand Up @@ -158,6 +160,50 @@ export const getDataGridSchemaFromKibanaFieldType = (
return schema;
};

/**
 * Helper to transform feature importance flattened fields with arrays back to object structure.
 *
 * Classification rows carry the per-class values in two flat, parallel arrays
 * (`classes.class_name` and `classes.importance`) that hold one entry per
 * feature/class pair, grouped by feature. The number of classes per feature is
 * therefore `classNames.length / featureNames.length`, which is used as the
 * slice stride (it is not necessarily equal to the number of features).
 *
 * Regression rows have no `classes.*` arrays, which previously crashed the
 * page. For those rows one importance value per feature is read instead.
 * NOTE(review): the regression values are assumed to be flattened as
 * `<resultsField>.feature_importance.importance`, and `FeatureImportance` is
 * assumed to allow an `importance` property - confirm against the type.
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @returns nested object structure of feature importance values, or an empty
 * array if the row does not contain usable feature importance fields
 */
export const getFeatureImportance = (row: any, mlResultsField: string): FeatureImportance[] => {
  const fieldPrefix = `${mlResultsField}.feature_importance`;
  const featureNames: string[] | undefined = row[`${fieldPrefix}.feature_name`];
  const classNames: string[] | undefined = row[`${fieldPrefix}.classes.class_name`];
  const classImportance: number[] | undefined = row[`${fieldPrefix}.classes.importance`];

  // Nothing to transform; also protects the division below from a zero divisor.
  if (!Array.isArray(featureNames) || featureNames.length === 0) {
    return [];
  }

  // Classification: rebuild the per-feature list of classes.
  if (Array.isArray(classNames) && Array.isArray(classImportance)) {
    const classesPerFeature = Math.floor(classNames.length / featureNames.length);

    return featureNames.map((fName, index) => {
      const offset = classesPerFeature * index;
      const featureClassNames = classNames.slice(offset, offset + classesPerFeature);
      const featureClassImportance = classImportance.slice(offset, offset + classesPerFeature);
      return {
        feature_name: fName,
        classes: featureClassNames.map((fClassName, fIndex) => {
          return { class_name: fClassName, importance: featureClassImportance[fIndex] };
        }),
      };
    });
  }

  // Regression: a single importance value per feature.
  const importance: number[] | undefined = row[`${fieldPrefix}.importance`];
  if (!Array.isArray(importance)) {
    return [];
  }

  return featureNames.map((fName, index) => ({
    feature_name: fName,
    importance: importance[index],
  }));
};

/**
 * Helper to transform top classes flattened fields with arrays back to object structure.
 *
 * The three flattened arrays (`class_name`, `class_probability`, `class_score`)
 * are read in parallel: the entry at index `i` of each array is combined into
 * one top-class object.
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @returns nested object structure of top classes values, or an empty array if
 * the row does not contain all of the flattened top classes fields
 */
export const getTopClasses = (row: any, mlResultsField: string): TopClasses => {
  const fieldPrefix = `${mlResultsField}.top_classes`;
  const classNames: string[] | undefined = row[`${fieldPrefix}.class_name`];
  const classProbabilities: number[] | undefined = row[`${fieldPrefix}.class_probability`];
  const classScores: number[] | undefined = row[`${fieldPrefix}.class_score`];

  // Rows without top classes fields must not crash the data grid.
  if (
    !Array.isArray(classNames) ||
    !Array.isArray(classProbabilities) ||
    !Array.isArray(classScores)
  ) {
    return [];
  }

  return classNames.map((className, index) => ({
    class_name: className,
    class_probability: classProbabilities[index],
    class_score: classScores[index],
  }));
};

export const useRenderCellValue = (
indexPattern: IndexPattern | undefined,
pagination: IndexPagination,
Expand Down Expand Up @@ -207,6 +253,15 @@ export const useRenderCellValue = (
return item[cId];
}

// For classification and regression results, we need to treat some fields with a custom transform.
if (cId === `${resultsField}.feature_importance`) {
return getFeatureImportance(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

if (cId === `${resultsField}.top_classes`) {
return getTopClasses(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

// Check whether the field name is available as a nested field.
return getNestedProperty(tableItems[adjustedRowIndex], cId, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_h

import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';

import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import {
euiDataGridStyle,
euiDataGridToolbarSettings,
getFeatureImportance,
getTopClasses,
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';

Expand Down Expand Up @@ -118,16 +123,16 @@ export const DataGrid: FC<Props> = memo(
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
const parsedFIArray = row[mlResultsField].feature_importance;
const parsedFIArray: FeatureImportance[] = getFeatureImportance(row, mlResultsField);
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row[mlResultsField][predictionFieldName] !== undefined
row[`${mlResultsField}.${predictionFieldName}`] !== undefined
) {
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
topClasses = getTopClasses(row, mlResultsField);
}

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${FEATURE_IMPORTANCE}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened feature importance fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${FEATURE_IMPORTANCE}.`)
);
}

if ((numTopClasses ?? 0) > 0) {
Expand All @@ -221,6 +225,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${TOP_CLASSES}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened top classes fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${TOP_CLASSES}.`)
);
}

// Only need to add these fields if we didn't use dest index pattern to get the fields
Expand Down