diff --git a/doc/jvm/xgboost4j_spark_tutorial.rst b/doc/jvm/xgboost4j_spark_tutorial.rst
index 15822cc9e..d104e78c3 100644
--- a/doc/jvm/xgboost4j_spark_tutorial.rst
+++ b/doc/jvm/xgboost4j_spark_tutorial.rst
@@ -239,7 +239,7 @@ Early Stopping
 
 Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds if the evaluation metric going away from the best iteration and early stop training iterations.
 
-In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
+When it comes to custom eval metrics, in addition to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training. For built-in eval metrics, XGBoost4J-Spark will automatically select the direction.
 
 For example, we need to maximize the evaluation metrics (set ``maximize_evaluation_metrics`` with true), and set ``num_early_stopping_rounds`` with 5. The evaluation metric of 10th iteration is the maximum one until now. In the following iterations, if there is no evaluation metric greater than the 10th iteration's (best one), the traning would be early stopped at 15th iteration.
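
For reviewers, a minimal usage sketch of the documented behavior (trainingDF and
myEval are placeholders; the setters shown are the existing XGBoost4J-Spark API):

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Built-in eval metric: with this change the direction is inferred
    // automatically ("auc" is a metric to maximize), so calling
    // setMaximizeEvaluationMetrics is no longer required.
    val classifier = new XGBoostClassifier(
        Map("objective" -> "binary:logistic", "num_round" -> 100))
      .setEvalMetric("auc")
      .setNumEarlyStoppingRounds(5)

    // Custom eval metric: the direction still has to be given explicitly,
    // otherwise training aborts with an IllegalArgumentException.
    // classifier.setCustomEval(myEval).setMaximizeEvaluationMetrics(true)

    val model = classifier.fit(trainingDF)
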
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index dd25e79d2..74d1ac81c 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
@@ -157,13 +158,21 @@ object XGBoost extends Serializable {
     try {
       val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
         .map(_.toString.toInt).getOrElse(0)
-      if (numEarlyStoppingRounds > 0) {
-        if (!params.contains("maximize_evaluation_metrics")) {
-          throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
+      val overridedParams = if (numEarlyStoppingRounds > 0 &&
+        !params.contains("maximize_evaluation_metrics")) {
+        if (params.contains("custom_eval")) {
+          throw new IllegalArgumentException("maximize_evaluation_metrics has to be " +
+            "specified when custom_eval is set")
         }
+        val eval_metric = params("eval_metric").toString
+        val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
+        logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
+        params + ("maximize_evaluation_metrics" -> maximize)
+      } else {
+        params
       }
       val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
-      val booster = SXGBoost.train(watches.toMap("train"), params, round,
+      val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round,
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index a621305b0..f343aab69 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -111,6 +111,10 @@ private[spark] object LearningTaskParams {
 
   val supportedObjectiveType = HashSet("regression", "classification")
 
-  val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
-    "auc", "aucpr", "ndcg", "map", "gamma-deviance")
+  val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
+
+  val evalMetricsToMinimize = HashSet("rmse", "mae", "logloss", "error", "merror",
+    "mlogloss", "gamma-deviance")
+
+  val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
 }
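
The inference itself is easy to check in isolation. Below is a self-contained
sketch under the same metric sets; withInferredDirection is an illustrative
helper for this note, not part of the patch:

    import scala.collection.immutable.HashSet

    object MaximizeInferenceSketch {
      // Mirrors the new split in LearningTaskParams.
      val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")

      // Fills in maximize_evaluation_metrics when early stopping is enabled,
      // no direction was given, and the eval metric is a built-in one.
      def withInferredDirection(params: Map[String, Any]): Map[String, Any] = {
        val rounds = params.get("num_early_stopping_rounds")
          .map(_.toString.toInt).getOrElse(0)
        if (rounds > 0 && !params.contains("maximize_evaluation_metrics")) {
          // Same guard as the patch: a custom eval has no known direction.
          require(!params.contains("custom_eval"),
            "maximize_evaluation_metrics has to be specified when custom_eval is set")
          val maximize = evalMetricsToMaximize contains params("eval_metric").toString
          params + ("maximize_evaluation_metrics" -> maximize)
        } else {
          params
        }
      }

      def main(args: Array[String]): Unit = {
        val p = withInferredDirection(
          Map("eval_metric" -> "auc", "num_early_stopping_rounds" -> 5))
        println(p("maximize_evaluation_metrics")) // prints: true
      }
    }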