diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index f7f9c3eaee34..ff95dd6959b7 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -137,7 +137,7 @@ public static Booster train( IEvaluation eval, int earlyStoppingRound, Booster booster) throws XGBoostError { - if(eval != null) { + if (eval != null) { return trainWithMultipleEvals(dtrain, params, round, watches, metrics, obj, new IEvaluation[]{eval}, earlyStoppingRound, booster); } else { @@ -164,6 +164,13 @@ public static BoosterResults trainWithResults( List mats = new ArrayList(); List logLines = new ArrayList(); + if (evals != null && evals.length == 0) { + throw new XGBoostError("Evaluation function array is empty, but not null."); + } else if (earlyStoppingRound != 0 && evals != null && evals.length > 1) { + Rabit.trackerPrint("Multiple evaluation functions provided, disabling early stopping."); + earlyStoppingRound = 0; + } + for (Map.Entry evalEntry : watches.entrySet()) { names.add(evalEntry.getKey()); mats.add(evalEntry.getValue()); @@ -209,7 +216,7 @@ public static BoosterResults trainWithResults( if (evalMats.length > 0) { float[] metricsOut = new float[evalMats.length]; String evalInfo = ""; - if (evals != null && evals.length > 0) { + if (evals != null) { for (int i = 0; i < evals.length; i++) { String evalLine = booster.evalSet(evalMats, evalNames, evals[i], metricsOut); evalInfo = evalInfo + " " + evalLine;