diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index e6fa7d709cff..41a355939bcc 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -185,6 +185,38 @@ public void testDescendMetricsWithBoundaryCondition() { } } + @Test + public void testEarlyStoppingForMultipleMetrics() { + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "true"); + } + }; + float[][] metrics = new float[3][5]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 5; j++) { + metrics[0][j] = j; + } + } + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + for (int i = 0; i < 5; i++) { + metrics[0][i] = 5 - i; + } + // when we have multiple datasets, the training metrics is not considered + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + for (int i = 0; i < 5; i++) { + metrics[1][i] = 5 - i; + } + // if any metrics off, we need to stop + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertFalse(onTrack); + } + @Test public void testDescendMetrics() { Map paramMap = new HashMap() {