From 9c4ff50e83b23ad460d5130deb24ad673c8fe666 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sat, 24 Nov 2018 00:18:07 -0800 Subject: [PATCH] [jvm-packages]Fix early stopping condition (#3928) * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * update version * 0.82 * fix early stopping condition * remove unused * update comments * udpate comments * update test --- .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 9 +- .../dmlc/xgboost4j/java/BoosterImplTest.java | 133 +++++++++++++++--- 2 files changed, 118 insertions(+), 24 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 2fa162751..3980bcdf1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -218,8 +218,15 @@ public class XGBoost { Map params, int earlyStoppingRounds, float[][] metrics, int iter) { boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params); boolean onTrack = false; + // we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration + if (iter < earlyStoppingRounds - 1) { + return true; + } float[] criterion = metrics[metrics.length - 1]; - for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) { + for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) { + // the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds` + // iterations goes to the expected direction, we should consider these `earlyStoppingRounds` + // as `onTrack` onTrack |= maximizeEvaluationMetrics ? criterion[iter - shift] >= criterion[iter - shift - 1] : criterion[iter - shift] <= criterion[iter - shift - 1]; diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 5b2ecdcaf..e6fa7d709 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -139,7 +139,7 @@ public class BoosterImplTest { } private static class IncreasingEval implements IEvaluation { - private int value = 0; + private int value = 1; @Override public String getMetric() { @@ -152,6 +152,39 @@ public class BoosterImplTest { } } + @Test + public void testDescendMetricsWithBoundaryCondition() { + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "false"); + } + }; + int totalIterations = 10; + int earlyStoppingRounds = 10; + float[][] metrics = new float[1][totalIterations]; + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = i; + } + for (int itr = 0; itr < totalIterations; itr++) { + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, + itr); + if (itr == totalIterations - 1) { + TestCase.assertFalse(onTrack); + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; + } + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, + totalIterations - 1); + TestCase.assertTrue(onTrack); + } else { + TestCase.assertTrue(onTrack); + } + } + } + @Test public void testDescendMetrics() { Map paramMap = new HashMap() { @@ -162,26 +195,69 @@ public class BoosterImplTest { put("maximize_evaluation_metrics", "false"); } }; - float[][] metrics = new float[1][5]; - for (int i = 0; i < 5; i++) { + int totalIterations = 10; + int earlyStoppingRounds = 5; + float[][] metrics = new float[1][totalIterations]; + for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, + totalIterations - 1); TestCase.assertFalse(onTrack); - for (int i = 0; i < 5; i++) { - metrics[0][i] = 5 - i; + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, + totalIterations - 1); TestCase.assertTrue(onTrack); - for (int i = 0; i < 5; i++) { - metrics[0][i] = 5 - i; + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; } - metrics[0][0] = 1; - metrics[0][2] = 5; - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + metrics[0][5] = 1; + metrics[0][6] = 2; + metrics[0][7] = 3; + metrics[0][8] = 4; + metrics[0][9] = 1; + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, + totalIterations - 1); TestCase.assertTrue(onTrack); } + @Test + public void testAscendMetricsWithBoundaryCondition() { + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + put("maximize_evaluation_metrics", "true"); + } + }; + int totalIterations = 10; + int earlyStoppingRounds = 10; + float[][] metrics = new float[1][totalIterations]; + for (int iter = 0; iter < totalIterations; iter++) { + if (iter == totalIterations - 1) { + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = i; + } + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); + TestCase.assertTrue(onTrack); + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; + } + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); + TestCase.assertFalse(onTrack); + } else { + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; + } + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter); + TestCase.assertTrue(onTrack); + } + } + } + @Test public void testAscendMetrics() { Map paramMap = new HashMap() { @@ -192,23 +268,28 @@ public class BoosterImplTest { put("maximize_evaluation_metrics", "true"); } }; - float[][] metrics = new float[1][5]; - for (int i = 0; i < 5; i++) { + int totalIterations = 10; + int earlyStoppingRounds = 5; + float[][] metrics = new float[1][totalIterations]; + for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } - boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); TestCase.assertTrue(onTrack); - for (int i = 0; i < 5; i++) { - metrics[0][i] = 5 - i; + for (int i = 0; i < totalIterations; i++) { + metrics[0][i] = totalIterations - i; } - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); TestCase.assertFalse(onTrack); - for (int i = 0; i < 5; i++) { + for (int i = 0; i < totalIterations; i++) { metrics[0][i] = i; } - metrics[0][0] = 6; - metrics[0][2] = 1; - onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); + metrics[0][5] = 9; + metrics[0][6] = 8; + metrics[0][7] = 7; + metrics[0][8] = 6; + metrics[0][9] = 9; + onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1); TestCase.assertTrue(onTrack); } @@ -237,7 +318,13 @@ public class BoosterImplTest { // Make sure we've stopped early. for (int w = 0; w < watches.size(); w++) { - for (int r = earlyStoppingRound + 1; r < round; r++) { + for (int r = 0; r < earlyStoppingRound; r++) { + TestCase.assertFalse(0.0f == metrics[w][r]); + } + } + + for (int w = 0; w < watches.size(); w++) { + for (int r = earlyStoppingRound; r < round; r++) { TestCase.assertEquals(0.0f, metrics[w][r]); } }