Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Return total SHAP per feature as a new result type #1387

Merged
merged 19 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
regression. (See {ml-pull}1340[#1340].)
* Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].)
* Add a peak_model_bytes field to model_size_stats. (See {ml-pull}1389[#1389].)
* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].)

=== Bug Fixes

Expand Down
7 changes: 7 additions & 0 deletions include/api/CDataFrameAnalysisRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

#include <api/CDataFrameAnalysisInstrumentation.h>
#include <api/CInferenceModelDefinition.h>
#include <api/CInferenceModelMetadata.h>
#include <api/ImportExport.h>

#include <rapidjson/fwd.h>

#include <boost/optional.hpp>

#include <cstddef>
#include <functional>
#include <memory>
Expand Down Expand Up @@ -66,6 +69,7 @@ class API_EXPORT CDataFrameAnalysisRunner {
using TProgressRecorder = std::function<void(double)>;
using TStrVecVec = std::vector<TStrVec>;
using TInferenceModelDefinitionUPtr = std::unique_ptr<CInferenceModelDefinition>;
using TOptionalInferenceModelMetadata = boost::optional<const CInferenceModelMetadata&>;

public:
//! The intention is that concrete objects of this hierarchy are constructed
Expand Down Expand Up @@ -141,6 +145,9 @@ class API_EXPORT CDataFrameAnalysisRunner {
virtual TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const;

//! \return A serialisable metadata of the trained model.
virtual TOptionalInferenceModelMetadata inferenceModelMetadata() const;

//! \return Reference to the analysis instrumentation.
virtual const CDataFrameAnalysisInstrumentation& instrumentation() const = 0;
//! \return Reference to the analysis instrumentation.
Expand Down
2 changes: 2 additions & 0 deletions include/api/CDataFrameAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class API_EXPORT CDataFrameAnalyzer {
core::CRapidJsonConcurrentLineWriter& writer) const;
void writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const;
void writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const;

private:
// This has values: -2 (unset), -1 (missing), >= 0 (control field index).
Expand Down
7 changes: 7 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <core/CSmallVector.h>

#include <api/CDataFrameTrainBoostedTreeRunner.h>
#include <api/CInferenceModelMetadata.h>
#include <api/ImportExport.h>

#include <rapidjson/fwd.h>
Expand Down Expand Up @@ -40,6 +41,8 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
static const std::string NUM_TOP_CLASSES;
static const std::string PREDICTION_FIELD_TYPE;
static const std::string CLASS_ASSIGNMENT_OBJECTIVE;
static const std::string CLASSES_FIELD_NAME;
static const std::string CLASS_NAME_FIELD_NAME;
static const TStrVec CLASS_ASSIGNMENT_OBJECTIVE_VALUES;

public:
Expand Down Expand Up @@ -70,6 +73,9 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
inferenceModelDefinition(const TStrVec& fieldNames,
const TStrVecVec& categoryNames) const override;

//! \return A serialisable metadata of the trained regression model.
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;

private:
static TLossFunctionUPtr loss(std::size_t numberClasses);

Expand All @@ -82,6 +88,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
private:
std::size_t m_NumTopClasses;
EPredictionFieldType m_PredictionFieldType;
mutable CInferenceModelMetadata m_InferenceModelMetadata;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
6 changes: 6 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeRegressionRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <maths/CBoostedTreeLoss.h>

#include <api/CDataFrameTrainBoostedTreeRunner.h>
#include <api/CInferenceModelMetadata.h>
#include <api/ImportExport.h>

#include <rapidjson/fwd.h>
Expand Down Expand Up @@ -51,10 +52,15 @@ class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
const TStrVecVec& categoryNameMap) const override;
//! \return A serialisable metadata of the trained regression model.
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;

private:
void validate(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const override;

private:
mutable CInferenceModelMetadata m_InferenceModelMetadata;
};

//! \brief Makes a core::CDataFrame boosted tree regression runner.
Expand Down
67 changes: 67 additions & 0 deletions include/api/CInferenceModelMetadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
#ifndef INCLUDED_ml_api_CInferenceModelMetadata_h
#define INCLUDED_ml_api_CInferenceModelMetadata_h

#include <maths/CBasicStatistics.h>
#include <maths/CLinearAlgebraEigen.h>

#include <api/CInferenceModelDefinition.h>
#include <api/ImportExport.h>

#include <string>

namespace ml {
namespace api {

//! \brief Class controls the serialization of the model meta information
//! (such as totol feature importance) into JSON format.
class API_EXPORT CInferenceModelMetadata {
public:
static const std::string JSON_CLASS_NAME_TAG;
static const std::string JSON_CLASSES_TAG;
static const std::string JSON_FEATURE_NAME_TAG;
static const std::string JSON_IMPORTANCE_TAG;
static const std::string JSON_MAX_TAG;
static const std::string JSON_MEAN_MAGNITUDE_TAG;
static const std::string JSON_MIN_TAG;
static const std::string JSON_MODEL_METADATA_TAG;
static const std::string JSON_TOTAL_FEATURE_IMPORTANCE_TAG;

public:
using TVector = maths::CDenseVector<double>;
using TStrVec = std::vector<std::string>;
using TRapidJsonWriter = core::CRapidJsonConcurrentLineWriter;

public:
//! Writes metadata using \p writer.
void write(TRapidJsonWriter& writer) const;
void columnNames(const TStrVec& columnNames);
void classValues(const TStrVec& classValues);
const std::string& typeString() const;
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
void addToFeatureImportance(std::size_t i, const TVector& values);

private:
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<TVector>::TAccumulator;
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
using TSizeMeanVarAccumulatorUMap = std::unordered_map<std::size_t, TMeanVarAccumulator>;
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;

private:
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;

private:
TSizeMeanVarAccumulatorUMap m_TotalShapValuesMeanVar;
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
TStrVec m_ColumnNames;
TStrVec m_ClassValues;
};
}
}

#endif //INCLUDED_ml_api_CInferenceModelMetadata_h
2 changes: 1 addition & 1 deletion include/maths/CBasicStatistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class MATHS_EXPORT CBasicStatistics {

if (ORDER > 1) {
T r{x - s_Moments[0]};
T r2{r * r};
T r2{las::componentwise(r) * las::componentwise(r)};
T dMean{mean - s_Moments[0]};
T dMean2{las::componentwise(dMean) * las::componentwise(dMean)};
T variance{s_Moments[1]};
Expand Down
3 changes: 3 additions & 0 deletions include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Get the maximum depth of any tree in \p forest.
static std::size_t depth(const TTreeVec& forest);

//! Get the column names.
const TStrVec& columnNames() const;

private:
//! Collects the elements of the path through decision tree that are updated together
struct SPathElement {
Expand Down
5 changes: 5 additions & 0 deletions lib/api/CDataFrameAnalysisRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ CDataFrameAnalysisRunner::inferenceModelDefinition(const TStrVec& /*fieldNames*/
return TInferenceModelDefinitionUPtr();
}

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameAnalysisRunner::inferenceModelMetadata() const {
return TOptionalInferenceModelMetadata();
}

CDataFrameAnalysisRunnerFactory::TRunnerUPtr
CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec) const {
auto result = this->makeImpl(spec);
Expand Down
17 changes: 17 additions & 0 deletions lib/api/CDataFrameAnalyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ void CDataFrameAnalyzer::run() {
analysisRunner->waitToFinish();
this->writeInferenceModel(*analysisRunner, outputWriter);
this->writeResultsOf(*analysisRunner, outputWriter);
// TODO reactivate once Java parsing is ready
// this->writeInferenceModelMetadata(*analysisRunner, outputWriter);
}
}

Expand Down Expand Up @@ -286,6 +288,21 @@ void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& ana
writer.flush();
}

void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const {
// Write model meta information
auto modelMetadata = analysis.inferenceModelMetadata();
if (modelMetadata) {
writer.StartObject();
writer.Key(modelMetadata->typeString());
writer.StartObject();
modelMetadata->write(writer);
writer.EndObject();
writer.EndObject();
}
writer.flush();
}

void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const {

Expand Down
52 changes: 44 additions & 8 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFramePredictiveModel.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/COrderings.h>
#include <maths/CTools.h>
#include <maths/CTreeShapFeatureImportance.h>
Expand All @@ -41,7 +42,6 @@ const std::string IS_TRAINING_FIELD_NAME{"is_training"};
const std::string PREDICTION_PROBABILITY_FIELD_NAME{"prediction_probability"};
const std::string PREDICTION_SCORE_FIELD_NAME{"prediction_score"};
const std::string TOP_CLASSES_FIELD_NAME{"top_classes"};
const std::string CLASS_NAME_FIELD_NAME{"class_name"};
const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
const std::string CLASS_SCORE_FIELD_NAME{"class_score"};

Expand Down Expand Up @@ -162,7 +162,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
}

if (featureImportance != nullptr) {
int numberClasses{static_cast<int>(classValues.size())};
std::size_t numberClasses{classValues.size()};
m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
m_InferenceModelMetadata.classValues(classValues);
featureImportance->shap(
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
Expand All @@ -175,20 +177,47 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureNames[i]);
if (shap[i].size() == 1) {
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](0));
// output feature importance for individual classes in binary case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0; j < numberClasses; ++j) {
double importance{(j == predictedClassId)
? shap[i](0)
: -shap[i](0)};
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(classValues[j]);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(importance);
writer.EndObject();
}
writer.EndArray();
} else {
for (int j = 0; j < shap[i].size() && j < numberClasses; ++j) {
writer.Key(classValues[j]);
// output feature importance for individual classes in multiclass case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(classValues[j]);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](j));
writer.EndObject();
}
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Double(shap[i].lpNorm<1>());
writer.EndArray();
}
writer.EndObject();
}
}
writer.EndArray();

for (std::size_t i = 0; i < shap.size(); ++i) {
tveasey marked this conversation as resolved.
Show resolved Hide resolved
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
}
}
});
}
writer.EndObject();
Expand Down Expand Up @@ -257,6 +286,11 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
return std::make_unique<CInferenceModelDefinition>(builder.build());
}

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
return m_InferenceModelMetadata;
}

// clang-format off
// The MAX_NUMBER_CLASSES must match the value used in the Java code. See the
// MAX_DEPENDENT_VARIABLE_CARDINALITY in the x-pack classification code.
Expand Down Expand Up @@ -291,5 +325,7 @@ CDataFrameTrainBoostedTreeClassifierRunnerFactory::makeImpl(
}

const std::string CDataFrameTrainBoostedTreeClassifierRunnerFactory::NAME{"classification"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSES_FIELD_NAME{"classes"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASS_NAME_FIELD_NAME{"class_name"};
}
}
21 changes: 17 additions & 4 deletions lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
writer.Bool(maths::CDataFrameUtils::isMissing(row[columnHoldingDependentVariable]) == false);
auto featureImportance = tree.shap();
if (featureImportance != nullptr) {
m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
featureImportance->shap(
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
row, [&writer, this](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
Expand All @@ -126,6 +127,13 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
}
}
writer.EndArray();

for (int i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeRegressionRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
}
}
});
}
writer.EndObject();
Expand All @@ -145,6 +153,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
return std::make_unique<CInferenceModelDefinition>(builder.build());
}

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
}

// clang-format off
const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"};
const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"};
Expand All @@ -160,7 +173,7 @@ const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() con

CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr
CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl(const CDataFrameAnalysisSpecification&) const {
HANDLE_FATAL(<< "Input error: classification has a non-optional parameter '"
HANDLE_FATAL(<< "Input error: regression has a non-optional parameter '"
<< CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'.")
return nullptr;
}
Expand Down
Loading