Skip to content

Commit

Permalink
[FEATURE][ML] auc_roc cannot be calculated when there are no inliers/… (
Browse files Browse the repository at this point in the history
#42853)

Also fixes a bug with the matching query for binary soft classification
  • Loading branch information
dimitris-athanasiou authored Jun 5, 2019
1 parent 69f5652 commit 44856a5
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -145,16 +146,23 @@ private String restLabelsAggName(ClassInfo classInfo) {
/**
 * Builds the AUC ROC result for the given class from the search aggregations.
 *
 * Looks up the filter aggregation for the evaluated label and the one for the
 * remaining labels, converts the PERCENTILES sub-aggregation of each into an
 * array, builds the ROC curve from the two arrays and computes its score.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers. The two
 * single-argument percentilesArray calls look like the removed lines and the
 * two-argument calls like their replacements - confirm against the repository.
 *
 * @param classInfo the class being evaluated; its name is used in the error messages
 * @param aggs      the aggregations holding the two filter aggregations
 * @return a Result with the score, plus the curve points only when includeCurve is set
 */
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
Filter restAgg = aggs.get(restLabelsAggName(classInfo));
// Variant without an error message (presumably the pre-change lines of the diff).
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES));
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES));
// Variant that passes the message percentilesArray raises when a percentile value is NaN:
// one message for the evaluated class, one for the remaining classes.
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
"[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
"[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
// The curve is returned only on request; otherwise an empty list accompanies the score.
return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
}

/**
 * Copies the values of a percentiles aggregation into a 99-element array,
 * storing each value at index (int) percent - 1.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers, so two
 * signatures and two forEach bodies appear; the one-argument form looks like
 * the removed version - confirm against the repository.
 *
 * @param percentiles      the percentiles aggregation to read
 * @param errorIfUndefined message of the bad-request exception raised when a
 *                         percentile value is NaN
 * @return the percentile values indexed by percent - 1
 */
private static double[] percentilesArray(Percentiles percentiles) {
private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
// assumes the aggregation was requested for integer percents 1..99 - TODO confirm
double[] result = new double[99];
percentiles.forEach(percentile -> result[((int) percentile.getPercent()) - 1] = percentile.getValue());
percentiles.forEach(percentile -> {
// presumably NaN is what the aggregation reports when no document matched - confirm
if (Double.isNaN(percentile.getValue())) {
throw ExceptionsHelper.badRequestException(errorIfUndefined);
}
result[((int) percentile.getPercent()) - 1] = percentile.getValue();
});
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public void evaluate(SearchResponse searchResponse, ActionListener<List<Evaluati

private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {

private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": 1 OR true");
private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");

@Override
public String getName() {
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ integTestRunner {
'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation',
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true',
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false',
'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics',
'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field',
'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
);

client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchResponse -> threadPool.generic().execute(() -> evaluation.evaluate(searchResponse, resultsListener)),
searchResponse -> threadPool.generic().execute(() -> {
// The evaluation runs on the generic thread pool; anything it throws is
// passed to the listener so the failure is reported to the caller.
try {
evaluation.evaluate(searchResponse, resultsListener);
} catch (Exception e) {
listener.onFailure(e);
}
}),
listener::onFailure
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ setup:
{
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.0
"outlier_score": 0.0,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -17,7 +19,9 @@ setup:
{
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.2
"outlier_score": 0.2,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -27,7 +31,9 @@ setup:
{
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.3
"outlier_score": 0.3,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -37,7 +43,9 @@ setup:
{
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.3
"outlier_score": 0.3,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -47,7 +55,9 @@ setup:
{
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.4
"outlier_score": 0.4,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -57,7 +67,9 @@ setup:
{
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.5
"outlier_score": 0.5,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -67,7 +79,9 @@ setup:
{
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.9
"outlier_score": 0.9,
"all_true_field": true,
"all_false_field": false
}
- do:
Expand All @@ -77,7 +91,9 @@ setup:
{
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.95
"outlier_score": 0.95,
"all_true_field": true,
"all_false_field": false
}
# This document misses the required fields and should be ignored
Expand Down Expand Up @@ -152,6 +168,44 @@ setup:
- match: { binary_soft_classification.auc_roc.score: 0.9899 }
- is_true: binary_soft_classification.auc_roc.curve

---
"Test binary_soft_classifition auc_roc given actual_field is always true":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"binary_soft_classification": {
"actual_field": "all_true_field",
"predicted_probability_field": "outlier_score",
"metrics": {
"auc_roc": {}
}
}
}
}
---
"Test binary_soft_classifition auc_roc given actual_field is always false":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"binary_soft_classification": {
"actual_field": "all_false_field",
"predicted_probability_field": "outlier_score",
"metrics": {
"auc_roc": {}
}
}
}
}
---
"Test binary_soft_classifition precision":
- do:
Expand Down

0 comments on commit 44856a5

Please sign in to comment.