[jvm-packages] fix early stopping doesn't work even without custom_eval setting (#6738)

* [jvm-packages] fix early stopping doesn't work even without custom_eval setting

* remove debug info

* resolve comment
This commit is contained in:
Bobby Wang 2021-03-07 12:19:40 +08:00 committed by GitHub
parent 5ae7f9944b
commit 49c22c23b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 1 deletions

View File

@ -149,7 +149,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
if (numEarlyStoppingRounds > 0 &&
!overridedParams.contains("maximize_evaluation_metrics")) {
if (overridedParams.contains("custom_eval")) {
if (overridedParams.getOrElse("custom_eval", null) != null) {
throw new IllegalArgumentException("custom_eval does not support early stopping")
}
val eval_metric = overridedParams("eval_metric").toString

View File

@ -78,4 +78,26 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
waitForSparkContextShutdown()
}
}
test("custom_eval does not support early stopping") {
val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(paramMap).fit(trainingDF)
}
assert(thrown.getMessage.contains("custom_eval does not support early stopping"))
}
test("early stopping should work without custom_eval setting") {
val paramMap = Map("eta" -> "0.1", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
val trainingDF = buildDataFrame(MultiClassification.train)
new XGBoostClassifier(paramMap).fit(trainingDF)
}
}