Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Data Frame Analytics: Fix feature importance cell value and decision path chart #82011

Merged
merged 7 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion x-pack/plugins/ml/common/types/feature_importance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ export interface ClassFeatureImportance {
class_name: string | boolean;
importance: number;
}

// TODO We should separate the interface because classes/importance
// aren't both optional; they are mutually exclusive (either/or).
export interface FeatureImportance {
feature_name: string;
importance?: number;
classes?: ClassFeatureImportance[];
importance?: number;
}

export interface TopClass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { EuiDataGridSorting } from '@elastic/eui';

import { multiColumnSortFactory } from './common';

describe('Transform: Define Pivot Common', () => {
describe('Data Frame Analytics: Data Grid Common', () => {
test('multiColumnSortFactory()', () => {
const data = [
{ s: 'a', n: 1 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import {
KBN_FIELD_TYPES,
} from '../../../../../../../src/plugins/data/public';

import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { extractErrorMessage } from '../../../../common/util/errors';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';

import {
BASIC_NUMERICAL_TYPES,
Expand Down Expand Up @@ -158,6 +160,90 @@ export const getDataGridSchemaFromKibanaFieldType = (
return schema;
};

/**
 * Resolves the class name to use in a feature importance entry.
 *
 * @param className - Class name as a string
 * @param isClassTypeBoolean - Whether the class name should be returned as a boolean
 * @returns `className === 'true'` when `isClassTypeBoolean` is set, otherwise the unchanged string
 */
const getClassName = (className: string, isClassTypeBoolean: boolean): string | boolean =>
  isClassTypeBoolean ? className === 'true' : className;
/**
 * Helper to transform feature importance flattened fields with arrays back to object structure.
 *
 * Reads the flattened `<mlResultsField>.feature_importance.*` arrays from the row and rebuilds
 * one entry per feature name. If the per-class arrays are present the classification shape
 * (`classes`) is returned, otherwise the regression shape (`importance`).
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @param isClassTypeBoolean - When true, class names are passed through `getClassName` to be
 *   converted from strings to booleans. Defaults to false.
 * @returns nested object structure of feature importance values, or an empty array when the
 *   flattened fields required to build it are missing from the row
 */
export const getFeatureImportance = (
  row: Record<string, any>,
  mlResultsField: string,
  isClassTypeBoolean = false
): FeatureImportance[] => {
  const fieldPrefix = `${mlResultsField}.feature_importance`;
  const featureNames: string[] | undefined = row[`${fieldPrefix}.feature_name`];
  const classNames: string[] | undefined = row[`${fieldPrefix}.classes.class_name`];
  const classImportance: number[] | undefined = row[`${fieldPrefix}.classes.importance`];

  if (featureNames === undefined) {
    return [];
  }

  // return object structure for classification job
  if (classNames !== undefined && classImportance !== undefined) {
    // The flattened class arrays are treated as being laid out feature by feature, so the
    // leading `classNames.length / featureNames.length` entries are the classes of one feature.
    const overallClassNames = classNames.slice(0, classNames.length / featureNames.length);

    return featureNames.map((fName, index) => {
      const offset = overallClassNames.length * index;
      const featureClassImportance = classImportance.slice(
        offset,
        offset + overallClassNames.length
      );
      return {
        feature_name: fName,
        classes: overallClassNames.map((fClassName, fIndex) => {
          return {
            class_name: getClassName(fClassName, isClassTypeBoolean),
            importance: featureClassImportance[fIndex],
          };
        }),
      };
    });
  }

  // return object structure for regression job
  const importance: number[] | undefined = row[`${fieldPrefix}.importance`];

  // Without importance values there is nothing meaningful to return; bail out
  // instead of throwing when indexing into an undefined array.
  if (importance === undefined) {
    return [];
  }

  return featureNames.map((fName, index) => ({
    feature_name: fName,
    importance: importance[index],
  }));
};

/**
 * Helper to transform the flattened top classes fields (parallel arrays) back into
 * an array of objects.
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @returns one `{ class_name, class_probability, class_score }` object per class name,
 *   or an empty array if any of the three flattened fields is missing from the row
 */
export const getTopClasses = (row: Record<string, any>, mlResultsField: string): TopClasses => {
  const fieldPrefix = `${mlResultsField}.top_classes`;
  const names: string[] | undefined = row[`${fieldPrefix}.class_name`];
  const probabilities: number[] | undefined = row[`${fieldPrefix}.class_probability`];
  const scores: number[] | undefined = row[`${fieldPrefix}.class_score`];

  const hasAllFields = names !== undefined && probabilities !== undefined && scores !== undefined;
  if (!hasAllFields) {
    return [];
  }

  return names.map((name, i) => {
    return {
      class_name: name,
      class_probability: probabilities[i],
      class_score: scores[i],
    };
  });
};

export const useRenderCellValue = (
indexPattern: IndexPattern | undefined,
pagination: IndexPagination,
Expand Down Expand Up @@ -207,6 +293,15 @@ export const useRenderCellValue = (
return item[cId];
}

// For classification and regression results, we need to treat some fields with a custom transform.
if (cId === `${resultsField}.feature_importance`) {
return getFeatureImportance(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

if (cId === `${resultsField}.top_classes`) {
return getTopClasses(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

// Try if the field name is available as a nested field.
return getNestedProperty(tableItems[adjustedRowIndex], cId, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_h

import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';

import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import {
euiDataGridStyle,
euiDataGridToolbarSettings,
getFeatureImportance,
getTopClasses,
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';

Expand Down Expand Up @@ -118,18 +123,28 @@ export const DataGrid: FC<Props> = memo(
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
const parsedFIArray = row[mlResultsField].feature_importance;
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row[mlResultsField][predictionFieldName] !== undefined
row[`${mlResultsField}.${predictionFieldName}`] !== undefined
) {
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
topClasses = getTopClasses(row, mlResultsField);
}

const isClassTypeBoolean = topClasses.reduce(
(p, c) => typeof c.class_name === 'boolean' || p,
false
);

const parsedFIArray: FeatureImportance[] = getFeatureImportance(
row,
mlResultsField,
isClassTypeBoolean
);

return (
<DecisionPathPopover
analysisType={analysisType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${FEATURE_IMPORTANCE}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened feature importance fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${FEATURE_IMPORTANCE}.`)
);
}

if ((numTopClasses ?? 0) > 0) {
Expand All @@ -221,6 +225,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${TOP_CLASSES}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened top classes fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${TOP_CLASSES}.`)
);
}

// Only need to add these fields if we didn't use dest index pattern to get the fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export const getIndexData = async (
index: jobConfig.dest.index,
body: {
fields: ['*'],
_source: [],
_source: false,
query: searchQuery,
from: pageIndex * pageSize,
size: pageSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ interface Props {
}

export const ExplorationResultsTable: FC<Props> = React.memo(
({ indexPattern, jobConfig, jobStatus, needsDestIndexPattern, searchQuery }) => {
({ indexPattern, jobConfig, needsDestIndexPattern, searchQuery }) => {
const {
services: {
mlServices: { mlApiServices },
Expand Down