[jvm-packages] Automatically set maximize_evaluation_metrics if not explicitly given in XGBoost4J-Spark (#4446)
* Automatically set maximize_evaluation_metrics if not explicitly given. * When custom_eval is set, require maximize_evaluation_metrics. * Update documents on early stop in XGBoost4J-Spark. * Fix code error.
This commit is contained in:
@@ -24,6 +24,7 @@ import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.commons.io.FileUtils
|
||||
@@ -157,13 +158,21 @@ object XGBoost extends Serializable {
|
||||
try {
|
||||
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
if (numEarlyStoppingRounds > 0) {
|
||||
if (!params.contains("maximize_evaluation_metrics")) {
|
||||
throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
|
||||
val overridedParams = if (numEarlyStoppingRounds > 0 &&
|
||||
!params.contains("maximize_evaluation_metrics")) {
|
||||
if (params.contains("custom_eval")) {
|
||||
throw new IllegalArgumentException("maximize_evaluation_metrics has to be "
|
||||
+ "specified when custom_eval is set")
|
||||
}
|
||||
val eval_metric = params("eval_metric").toString
|
||||
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
|
||||
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
|
||||
params + ("maximize_evaluation_metrics" -> maximize)
|
||||
} else {
|
||||
params
|
||||
}
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||
val booster = SXGBoost.train(watches.toMap("train"), params, round,
|
||||
val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||
|
||||
@@ -111,6 +111,10 @@ private[spark] object LearningTaskParams {
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
|
||||
"auc", "aucpr", "ndcg", "map", "gamma-deviance")
|
||||
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
|
||||
|
||||
val evalMetricsToMinimize = HashSet("rmse", "mae", "logloss", "error", "merror",
|
||||
"mlogloss", "gamma-deviance")
|
||||
|
||||
val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user