diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index f838512af67e..68a42fd12ef4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -149,7 +149,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds if (numEarlyStoppingRounds > 0 && !overridedParams.contains("maximize_evaluation_metrics")) { - if (overridedParams.contains("custom_eval")) { + if (overridedParams.getOrElse("custom_eval", null) != null) { throw new IllegalArgumentException("custom_eval does not support early stopping") } val eval_metric = overridedParams("eval_metric").toString diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index c802fc0b0448..50596c69f7ae 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -78,4 +78,26 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll { waitForSparkContextShutdown() } } + + test("custom_eval does not support early stopping") { + val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5, + "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2) + val trainingDF = buildDataFrame(MultiClassification.train) + + val thrown = intercept[IllegalArgumentException] { + new XGBoostClassifier(paramMap).fit(trainingDF) + } + + assert(thrown.getMessage.contains("custom_eval does not support early stopping")) + } + + test("early stopping should work without custom_eval setting") { + val paramMap = Map("eta" -> "0.1", "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5, + "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2) + val trainingDF = buildDataFrame(MultiClassification.train) + + new XGBoostClassifier(paramMap).fit(trainingDF) + } }