[jvm-packages] Automatically set maximize_evaluation_metrics if not explicitly given in XGBoost4J-Spark (#4446)

* Automatically set maximize_evaluation_metrics if not explicitly given.

* When custom_eval is set, require maximize_evaluation_metrics.

* Update documents on early stop in XGBoost4J-Spark.

* Fix code error.
This commit is contained in:
Shaochen Shi
2019-05-09 14:49:44 -05:00
committed by Nan Zhu
parent 8da4907e89
commit 18e4fc3690
3 changed files with 20 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
@@ -157,13 +158,21 @@ object XGBoost extends Serializable {
try {
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0)
if (numEarlyStoppingRounds > 0) {
if (!params.contains("maximize_evaluation_metrics")) {
throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
val overridedParams = if (numEarlyStoppingRounds > 0 &&
!params.contains("maximize_evaluation_metrics")) {
if (params.contains("custom_eval")) {
throw new IllegalArgumentException("maximize_evaluation_metrics has to be "
+ "specified when custom_eval is set")
}
val eval_metric = params("eval_metric").toString
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
params + ("maximize_evaluation_metrics" -> maximize)
} else {
params
}
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.toMap("train"), params, round,
val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)

View File

@@ -111,6 +111,10 @@ private[spark] object LearningTaskParams {
val supportedObjectiveType = HashSet("regression", "classification")
val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
"auc", "aucpr", "ndcg", "map", "gamma-deviance")
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
val evalMetricsToMinimize = HashSet("rmse", "mae", "logloss", "error", "merror",
"mlogloss", "gamma-deviance")
val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
}