[jvm-packages]Fix early stopping condition (#3928)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* update version

* 0.82

* fix early stopping condition

* remove unused

* update comments

* udpate comments

* update test
This commit is contained in:
Nan Zhu 2018-11-24 00:18:07 -08:00 committed by GitHub
parent 42cac4a30b
commit 9c4ff50e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 24 deletions

View File

@ -218,8 +218,15 @@ public class XGBoost {
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) { Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params); boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
boolean onTrack = false; boolean onTrack = false;
// we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration
if (iter < earlyStoppingRounds - 1) {
return true;
}
float[] criterion = metrics[metrics.length - 1]; float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) { for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ? onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] : criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1]; criterion[iter - shift] <= criterion[iter - shift - 1];

View File

@ -139,7 +139,7 @@ public class BoosterImplTest {
} }
private static class IncreasingEval implements IEvaluation { private static class IncreasingEval implements IEvaluation {
private int value = 0; private int value = 1;
@Override @Override
public String getMetric() { public String getMetric() {
@ -152,6 +152,39 @@ public class BoosterImplTest {
} }
} }
@Test
public void testDescendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
for (int itr = 0; itr < totalIterations; itr++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
itr);
if (itr == totalIterations - 1) {
TestCase.assertFalse(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
} else {
TestCase.assertTrue(onTrack);
}
}
}
@Test @Test
public void testDescendMetrics() { public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() { Map<String, Object> paramMap = new HashMap<String, Object>() {
@ -162,26 +195,69 @@ public class BoosterImplTest {
put("maximize_evaluation_metrics", "false"); put("maximize_evaluation_metrics", "false");
} }
}; };
float[][] metrics = new float[1][5]; int totalIterations = 10;
for (int i = 0; i < 5; i++) { int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i; metrics[0][i] = i;
} }
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertFalse(onTrack); TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) { for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = 5 - i; metrics[0][i] = totalIterations - i;
} }
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack); TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) { for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = 5 - i; metrics[0][i] = totalIterations - i;
} }
metrics[0][0] = 1; metrics[0][5] = 1;
metrics[0][2] = 5; metrics[0][6] = 2;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); metrics[0][7] = 3;
metrics[0][8] = 4;
metrics[0][9] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack); TestCase.assertTrue(onTrack);
} }
@Test
public void testAscendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int iter = 0; iter < totalIterations; iter++) {
if (iter == totalIterations - 1) {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertFalse(onTrack);
} else {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
}
}
}
@Test @Test
public void testAscendMetrics() { public void testAscendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() { Map<String, Object> paramMap = new HashMap<String, Object>() {
@ -192,23 +268,28 @@ public class BoosterImplTest {
put("maximize_evaluation_metrics", "true"); put("maximize_evaluation_metrics", "true");
} }
}; };
float[][] metrics = new float[1][5]; int totalIterations = 10;
for (int i = 0; i < 5; i++) { int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i; metrics[0][i] = i;
} }
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack); TestCase.assertTrue(onTrack);
for (int i = 0; i < 5; i++) { for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = 5 - i; metrics[0][i] = totalIterations - i;
} }
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertFalse(onTrack); TestCase.assertFalse(onTrack);
for (int i = 0; i < 5; i++) { for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i; metrics[0][i] = i;
} }
metrics[0][0] = 6; metrics[0][5] = 9;
metrics[0][2] = 1; metrics[0][6] = 8;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4); metrics[0][7] = 7;
metrics[0][8] = 6;
metrics[0][9] = 9;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack); TestCase.assertTrue(onTrack);
} }
@ -237,7 +318,13 @@ public class BoosterImplTest {
// Make sure we've stopped early. // Make sure we've stopped early.
for (int w = 0; w < watches.size(); w++) { for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound + 1; r < round; r++) { for (int r = 0; r < earlyStoppingRound; r++) {
TestCase.assertFalse(0.0f == metrics[w][r]);
}
}
for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound; r < round; r++) {
TestCase.assertEquals(0.0f, metrics[w][r]); TestCase.assertEquals(0.0f, metrics[w][r]);
} }
} }