From 4ae225a08d3ff75d8d3a9f59ca477680cad2579c Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Tue, 23 Oct 2018 14:53:13 -0700 Subject: [PATCH] [Blocking][jvm-packages] fix the early stopping feature (#3808) * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * temp * add method for classifier and regressor * update tutorial * address the comments * update --- doc/jvm/xgboost4j_spark_tutorial.rst | 9 +++ .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 5 ++ .../scala/spark/XGBoostClassifier.scala | 3 + .../scala/spark/XGBoostRegressor.scala | 3 + .../spark/params/LearningTaskParams.scala | 7 +++ .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 60 +++++++++++++----- .../dmlc/xgboost4j/java/BoosterImplTest.java | 61 +++++++++++++++++++ 7 files changed, 134 insertions(+), 14 deletions(-) diff --git a/doc/jvm/xgboost4j_spark_tutorial.rst b/doc/jvm/xgboost4j_spark_tutorial.rst index 4e6fb9d9d..d72d0a6a8 100644 --- a/doc/jvm/xgboost4j_spark_tutorial.rst +++ b/doc/jvm/xgboost4j_spark_tutorial.rst @@ -183,6 +183,15 @@ After we set XGBoostClassifier parameters and feature/label column, we can build val xgbClassificationModel = xgbClassifier.fit(xgbInput) +Early Stopping +---------------- + +Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training. + +In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training. + +After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations. + Prediction ========== diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 4177af88a..bfacd765f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -132,6 +132,11 @@ object XGBoost extends Serializable { try { val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") .map(_.toString.toInt).getOrElse(0) + if (numEarlyStoppingRounds > 0) { + if (!params.contains("maximize_evaluation_metrics")) { + throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified") + } + } val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) val booster = SXGBoost.train(watches.train, params, round, watches.toMap, metrics, obj, eval, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 869c1fe9c..0206f67dd 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -140,6 +140,9 @@ class XGBoostClassifier ( def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + def setMaximizeEvaluationMetrics(value: Boolean): this.type = + set(maximizeEvaluationMetrics, value) + def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) def setCustomEval(value: EvalTrait): this.type = set(customEval, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 5ce659bb0..9ac6ab9a4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -140,6 +140,9 @@ class XGBoostRegressor ( def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + def setMaximizeEvaluationMetrics(value: Boolean): this.type = + set(maximizeEvaluationMetrics, value) + def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) def setCustomEval(value: EvalTrait): this.type = set(customEval, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 5d3106a02..804e9a387 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -87,6 +87,13 @@ private[spark] trait LearningTaskParams extends Params { final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds) + + final val maximizeEvaluationMetrics = new BooleanParam(this, "maximizeEvaluationMetrics", + "define the expected optimization to the evaluation metrics, true to maximize otherwise" + + " minimize it") + + final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics) + setDefault(objective -> "reg:linear", baseScore -> 0.5, trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0) } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index df030105d..2fa162751 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -118,9 +118,9 @@ public class XGBoost { * performance on the validation set. * @param metrics array containing the evaluation metrics for each matrix in watches for each * iteration - * @param earlyStoppingRound if non-zero, training would be stopped + * @param earlyStoppingRounds if non-zero, training would be stopped * after a specified number of consecutive - * increases in any evaluation metric. + * goes to the unexpected direction in any evaluation metric. * @param obj customized objective * @param eval customized evaluation * @param booster train from scratch if set to null; train from an existing booster if not null. @@ -134,7 +134,7 @@ public class XGBoost { float[][] metrics, IObjective obj, IEvaluation eval, - int earlyStoppingRound, + int earlyStoppingRounds, Booster booster) throws XGBoostError { //collect eval matrixs @@ -196,17 +196,14 @@ public class XGBoost { for (int i = 0; i < metricsOut.length; i++) { metrics[i][iter] = metricsOut[i]; } - - boolean decreasing = true; - float[] criterion = metrics[metrics.length - 1]; - for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) { - decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1]; - } - - if (!decreasing) { - Rabit.trackerPrint(String.format( - "early stopping after %d decreasing rounds", earlyStoppingRound)); - break; + if (earlyStoppingRounds > 0) { + boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter); + if (!onTrack) { + String reversedDirection = getReversedDirection(params); + Rabit.trackerPrint(String.format( + "early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection)); + break; + } } if (Rabit.getRank() == 0) { Rabit.trackerPrint(evalInfo + '\n'); @@ -217,6 +214,41 @@ public class XGBoost { return booster; } + static boolean judgeIfTrainingOnTrack( + Map params, int earlyStoppingRounds, float[][] metrics, int iter) { + boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params); + boolean onTrack = false; + float[] criterion = metrics[metrics.length - 1]; + for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) { + onTrack |= maximizeEvaluationMetrics ? + criterion[iter - shift] >= criterion[iter - shift - 1] : + criterion[iter - shift] <= criterion[iter - shift - 1]; + } + return onTrack; + } + + private static String getReversedDirection(Map params) { + String reversedDirection = null; + if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) { + reversedDirection = "descending"; + } else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) { + reversedDirection = "ascending"; + } + return reversedDirection; + } + + private static boolean getMetricsExpectedDirection(Map params) { + try { + String maximize = String.valueOf(params.get("maximize_evaluation_metrics")); + assert(maximize != null); + return Boolean.valueOf(maximize); + } catch (Exception ex) { + logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," + + " allowed value: true/false", ex); + throw ex; + } + } + /** * Cross-validation with given parameters. * diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 85d2b61d2..5b2ecdcaf 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -152,6 +152,66 @@ public class BoosterImplTest { } } + @Test + public void testDescendMetrics() { + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "false"); + } + }; + float[][] metrics = new float[1][5]; + for (int i = 0; i < 5; i++) { + metrics[0][i] = i; + } + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertFalse(onTrack); + for (int i = 0; i < 5; i++) { + metrics[0][i] = 5 - i; + } + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + for (int i = 0; i < 5; i++) { + metrics[0][i] = 5 - i; + } + metrics[0][0] = 1; + metrics[0][2] = 5; + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + } + + @Test + public void testAscendMetrics() { + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "true"); + } + }; + float[][] metrics = new float[1][5]; + for (int i = 0; i < 5; i++) { + metrics[0][i] = i; + } + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + for (int i = 0; i < 5; i++) { + metrics[0][i] = 5 - i; + } + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertFalse(onTrack); + for (int i = 0; i < 5; i++) { + metrics[0][i] = i; + } + metrics[0][0] = 6; + metrics[0][2] = 1; + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + TestCase.assertTrue(onTrack); + } + @Test public void testBoosterEarlyStop() throws XGBoostError, IOException { DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); @@ -162,6 +222,7 @@ public class BoosterImplTest { put("max_depth", 3); put("silent", 1); put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "false"); } }; Map watches = new LinkedHashMap<>();