diff --git a/doc/jvm/xgboost4j_spark_tutorial.rst b/doc/jvm/xgboost4j_spark_tutorial.rst index 48f1e0d3d..caf645f06 100644 --- a/doc/jvm/xgboost4j_spark_tutorial.rst +++ b/doc/jvm/xgboost4j_spark_tutorial.rst @@ -194,11 +194,11 @@ After we set XGBoostClassifier parameters and feature/label column, we can build Early Stopping ---------------- -Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training. +Early stopping is a feature to prevent unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly calling ``setNumEarlyStoppingRounds`` on an XGBoostClassifier or XGBoostRegressor, we can define the number of rounds the evaluation metric may go without improving on the best iteration before training is stopped early. In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training. -After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations. +For example, suppose we want to maximize the evaluation metric (set ``maximize_evaluation_metrics`` to true) and set ``num_early_stopping_rounds`` to 5. If the evaluation metric of the 10th iteration is the best so far and none of the following iterations produces a greater value, training is stopped early at the 15th iteration. 
Training with Evaluation Sets ---------------- diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index f2bf19790..06ea2eb4b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -140,6 +140,8 @@ public class XGBoost { //collect eval matrixs String[] evalNames; DMatrix[] evalMats; + float bestScore; + int bestIteration; List names = new ArrayList(); List mats = new ArrayList(); @@ -150,6 +152,12 @@ public class XGBoost { evalNames = names.toArray(new String[names.size()]); evalMats = mats.toArray(new DMatrix[mats.size()]); + if (isMaximizeEvaluation(params)) { + bestScore = -Float.MAX_VALUE; + } else { + bestScore = Float.MAX_VALUE; + } + bestIteration = 0; metrics = metrics == null ? new float[evalNames.length][round] : metrics; //collect all data matrixs @@ -196,12 +204,27 @@ public class XGBoost { for (int i = 0; i < metricsOut.length; i++) { metrics[i][iter] = metricsOut[i]; } + + // If there is more than one evaluation datasets, the last one would be used + // to determinate early stop. 
+ float score = metricsOut[metricsOut.length - 1]; + if (isMaximizeEvaluation(params)) { + // Update best score if the current score is better (no update when equal) + if (score > bestScore) { + bestScore = score; + bestIteration = iter; + } + } else { + if (score < bestScore) { + bestScore = score; + bestIteration = iter; + } + } if (earlyStoppingRounds > 0) { - boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter); - if (!onTrack) { - String reversedDirection = getReversedDirection(params); + if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) { Rabit.trackerPrint(String.format( - "early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection)); + "early stopping after %d rounds away from the best iteration", + earlyStoppingRounds)); break; } } @@ -214,42 +237,11 @@ public class XGBoost { return booster; } - static boolean judgeIfTrainingOnTrack( - Map params, int earlyStoppingRounds, float[][] metrics, int iter) { - boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params); - boolean onTrack = false; - // we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration - if (iter < earlyStoppingRounds - 1) { - return true; - } - for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) { - float[] criterion = metrics[metricsId]; - for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) { - // the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds` - // iterations goes to the expected direction, we should consider these `earlyStoppingRounds` - // as `onTrack` - onTrack |= maximizeEvaluationMetrics ? 
- criterion[iter - shift] >= criterion[iter - shift - 1] : - criterion[iter - shift] <= criterion[iter - shift - 1]; - } - if (!onTrack) { - return false; - } - } - return onTrack; + static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) { + return iter - bestIteration >= earlyStoppingRounds; } - private static String getReversedDirection(Map params) { - String reversedDirection = null; - if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) { - reversedDirection = "descending"; - } else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) { - reversedDirection = "ascending"; - } - return reversedDirection; - } - - private static boolean getMetricsExpectedDirection(Map params) { + private static boolean isMaximizeEvaluation(Map params) { try { String maximize = String.valueOf(params.get("maximize_evaluation_metrics")); assert(maximize != null); diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index f7b2ff8e3..c03174261 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -154,188 +154,159 @@ public class BoosterImplTest { @Test public void testDescendMetricsWithBoundaryCondition() { - Map paramMap = new HashMap() { - { - put("max_depth", 3); - put("silent", 1); - put("objective", "binary:logistic"); - put("maximize_evaluation_metrics", "false"); - } - }; - int totalIterations = 10; - int earlyStoppingRounds = 10; + // maximize_evaluation_metrics = false + int totalIterations = 11; + int earlyStoppingRound = 10; float[][] metrics = new float[1][totalIterations]; for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } + int bestIteration = 0; + for (int itr = 0; itr < totalIterations; itr++) { - boolean onTrack = 
XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, - itr); + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration); if (itr == totalIterations - 1) { - TestCase.assertFalse(onTrack); - for (int i = 0; i < totalIterations; i++) { - metrics[0][i] = totalIterations - i; - } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, - totalIterations - 1); - TestCase.assertTrue(onTrack); + TestCase.assertTrue(es); } else { - TestCase.assertTrue(onTrack); + TestCase.assertFalse(es); } } } @Test public void testEarlyStoppingForMultipleMetrics() { - Map paramMap = new HashMap() { - { - put("max_depth", 3); - put("silent", 1); - put("objective", "binary:logistic"); - put("maximize_evaluation_metrics", "true"); - } - }; + // maximize_evaluation_metrics = true int earlyStoppingRound = 3; int totalIterations = 5; int numOfMetrics = 3; float[][] metrics = new float[numOfMetrics][totalIterations]; + // Only assign metric values to the first dataset, zeros for other datasets for (int i = 0; i < numOfMetrics; i++) { for (int j = 0; j < totalIterations; j++) { metrics[0][j] = j; } } + int bestIteration; + for (int i = 0; i < totalIterations; i++) { - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i); - TestCase.assertTrue(onTrack); + bestIteration = i; + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); + TestCase.assertFalse(es); } + + // when we have multiple datasets, only the last one was used to determinate early stop + // Here we changed the metric of the first dataset, it doesn't have any effect to the final result for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; } - // when we have multiple datasets, the training metrics is not considered for (int i = 0; i < totalIterations; i++) { - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i); - TestCase.assertTrue(onTrack); + 
bestIteration = i; + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); + TestCase.assertFalse(es); } + + // Now assign metric values to the last dataset. for (int i = 0; i < totalIterations; i++) { - metrics[1][i] = totalIterations - i; + metrics[2][i] = totalIterations - i; } + bestIteration = 0; + for (int i = 0; i < totalIterations; i++) { // if any metrics off, we need to stop - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i); - if (i >= earlyStoppingRound - 1) { - TestCase.assertFalse(onTrack); + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration); + if (i >= earlyStoppingRound) { + TestCase.assertTrue(es); } else { - TestCase.assertTrue(onTrack); + TestCase.assertFalse(es); } } } @Test public void testDescendMetrics() { - Map paramMap = new HashMap() { - { - put("max_depth", 3); - put("silent", 1); - put("objective", "binary:logistic"); - put("maximize_evaluation_metrics", "false"); - } - }; + // maximize_evaluation_metrics = false int totalIterations = 10; int earlyStoppingRounds = 5; float[][] metrics = new float[1][totalIterations]; for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, - totalIterations - 1); - TestCase.assertFalse(onTrack); + int bestIteration = 0; + + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertTrue(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, - totalIterations - 1); - TestCase.assertTrue(onTrack); + bestIteration = totalIterations - 1; + + es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertFalse(es); + for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; } - metrics[0][5] = 1; 
- metrics[0][6] = 2; - metrics[0][7] = 3; - metrics[0][8] = 4; - metrics[0][9] = 1; - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, - totalIterations - 1); - TestCase.assertTrue(onTrack); + metrics[0][4] = 1; + metrics[0][9] = 5; + + bestIteration = 4; + + es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertTrue(es); } @Test public void testAscendMetricsWithBoundaryCondition() { - Map paramMap = new HashMap() { - { - put("max_depth", 3); - put("silent", 1); - put("objective", "binary:logistic"); - put("maximize_evaluation_metrics", "true"); - } - }; - int totalIterations = 10; + // maximize_evaluation_metrics = true + int totalIterations = 11; int earlyStoppingRounds = 10; float[][] metrics = new float[1][totalIterations]; - for (int iter = 0; iter < totalIterations; iter++) { - if (iter == totalIterations - 1) { - for (int i = 0; i < totalIterations; i++) { - metrics[0][i] = i; - } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); - TestCase.assertTrue(onTrack); - for (int i = 0; i < totalIterations; i++) { - metrics[0][i] = totalIterations - i; - } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); - TestCase.assertFalse(onTrack); + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; + } + int bestIteration = 0; + + for (int itr = 0; itr < totalIterations; itr++) { + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration); + if (itr == totalIterations - 1) { + TestCase.assertTrue(es); } else { - for (int i = 0; i < totalIterations; i++) { - metrics[0][i] = totalIterations - i; - } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); - TestCase.assertTrue(onTrack); + TestCase.assertFalse(es); } } } @Test public void testAscendMetrics() { - Map paramMap = new HashMap() { - { - put("max_depth", 3); 
- put("silent", 1); - put("objective", "binary:logistic"); - put("maximize_evaluation_metrics", "true"); - } - }; + // maximize_evaluation_metrics = true int totalIterations = 10; int earlyStoppingRounds = 5; float[][] metrics = new float[1][totalIterations]; - for (int i = 0; i < totalIterations; i++) { - metrics[0][i] = i; - } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); - TestCase.assertTrue(onTrack); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = totalIterations - i; } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); - TestCase.assertFalse(onTrack); + int bestIteration = 0; + + boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertTrue(es); for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } - metrics[0][5] = 9; - metrics[0][6] = 8; - metrics[0][7] = 7; - metrics[0][8] = 6; - metrics[0][9] = 9; - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); - TestCase.assertTrue(onTrack); + bestIteration = totalIterations - 1; + + es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertFalse(es); + + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = i; + } + metrics[0][4] = 9; + metrics[0][9] = 4; + + bestIteration = 4; + + es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration); + TestCase.assertTrue(es); } @Test @@ -362,13 +333,13 @@ public class BoosterImplTest { // Make sure we've stopped early. 
for (int w = 0; w < watches.size(); w++) { - for (int r = 0; r < earlyStoppingRound; r++) { + for (int r = 0; r <= earlyStoppingRound; r++) { TestCase.assertFalse(0.0f == metrics[w][r]); } } for (int w = 0; w < watches.size(); w++) { - for (int r = earlyStoppingRound; r < round; r++) { + for (int r = earlyStoppingRound + 1; r < round; r++) { TestCase.assertEquals(0.0f, metrics[w][r]); } }