From 0d27d55706c167171050318248ceef280bf096b3 Mon Sep 17 00:00:00 2001 From: "Leah E. Cole" <6719667+leahecole@users.noreply.github.com> Date: Thu, 30 Jan 2020 14:09:54 -0800 Subject: [PATCH] docs(samples): add feature importance to predict sample (#277) * Add feature importance to predict sample * Fix license header * fix: skip tensorflow linkinator - flaky * Add bens map suggestion * Fix lint and errors Co-authored-by: Benjamin E. Coe --- automl/snippets/tables/predict.v1beta1.js | 24 +++++++++++++++- .../test/automlTablesPredict.v1beta1.test.js | 28 ++++++++++--------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/automl/snippets/tables/predict.v1beta1.js b/automl/snippets/tables/predict.v1beta1.js index 56d09148c56..eae7d14234f 100644 --- a/automl/snippets/tables/predict.v1beta1.js +++ b/automl/snippets/tables/predict.v1beta1.js @@ -63,7 +63,11 @@ async function main( // Params is additional domain-specific parameters. // Currently there is no additional parameters supported. client - .predict({name: modelFullId, payload: payload, params: {}}) + .predict({ + name: modelFullId, + payload: payload, + params: {feature_importance: true}, + }) .then(responses => { console.log(responses); console.log(`Prediction results:`); @@ -71,6 +75,24 @@ async function main( for (const result of responses[0].payload) { console.log(`Predicted class name: ${result.displayName}`); console.log(`Predicted class score: ${result.tables.score}`); + + // Get features of top importance + const featureList = result.tables.tablesModelColumnInfo.map( + columnInfo => { + return { + importance: columnInfo.featureImportance, + displayName: columnInfo.columnDisplayName, + }; + } + ); + // Sort features by their importance, highest importance first + featureList.sort(function(a, b) { + return b.importance - a.importance; + }); + + // Print top 10 important features + console.log('Features of top importance'); + console.log(featureList.slice(0, 10)); } }) .catch(err => { diff --git a/automl/snippets/test/automlTablesPredict.v1beta1.test.js b/automl/snippets/test/automlTablesPredict.v1beta1.test.js index 24ff53f21cc..cb8980275b9 100644 --- a/automl/snippets/test/automlTablesPredict.v1beta1.test.js +++ b/automl/snippets/test/automlTablesPredict.v1beta1.test.js @@ -1,16 +1,17 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/** + * Copyright 2019 Google LLC + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ 'use strict'; @@ -42,6 +43,7 @@ describe('Tables PredictionAPI', () => { // Run single prediction on predictTest.csv in resource folder const output = exec(`${cmdPredict} predict "${modelId}" "${filePath}"`); assert.match(output, /Prediction results:/); + assert.match(output, /Features of top importance:/); }); it.skip(`should perform batch prediction using GCS as source and