Skip to content

Commit

Permalink
[jvm-packages] fix early stopping doesn't work even without custom_ev…
Browse files Browse the repository at this point in the history
…al setting (#6738)

* [jvm-packages] fix early stopping doesn't work even without custom_eval setting

* remove debug info

* resolve comment
  • Loading branch information
wbo4958 authored Mar 7, 2021
1 parent 5ae7f99 commit 49c22c2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
if (numEarlyStoppingRounds > 0 &&
!overridedParams.contains("maximize_evaluation_metrics")) {
if (overridedParams.contains("custom_eval")) {
if (overridedParams.getOrElse("custom_eval", null) != null) {
throw new IllegalArgumentException("custom_eval does not support early stopping")
}
val eval_metric = overridedParams("eval_metric").toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,26 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
waitForSparkContextShutdown()
}
}

test("custom_eval does not support early stopping") {
val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)

val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(paramMap).fit(trainingDF)
}

assert(thrown.getMessage.contains("custom_eval does not support early stopping"))
}

test("early stopping should work without custom_eval setting") {
val paramMap = Map("eta" -> "0.1", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)

new XGBoostClassifier(paramMap).fit(trainingDF)
}
}

0 comments on commit 49c22c2

Please sign in to comment.