[jvm-packages] Fix early stop with xgboost4j-spark (#4176)

* Fix early stop with xgboost4j-spark

* Update XGBoost.java

* Update XGBoost.java

* Update XGBoost.java

Use -Float.MAX_VALUE as the lower bound, in case there is a positive metric.

* Only update best score if the current score is better (no update when equal)

* Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
Yanbo Liang
2019-03-01 13:02:57 -08:00
committed by Nan Zhu
parent 7ea5675679
commit 9fefa2128d
3 changed files with 113 additions and 150 deletions

View File

@@ -154,188 +154,159 @@ public class BoosterImplTest {
/**
 * Boundary-condition test for early stopping when the evaluation metric is minimised
 * (maximize_evaluation_metrics = false).
 *
 * NOTE(review): this span is a rendered diff whose removed and added lines are interleaved
 * without +/- markers: totalIterations is declared twice, and both judgeIfTrainingOnTrack and
 * shouldEarlyStop are exercised. It does not compile as shown. Presumably the shouldEarlyStop
 * lines are the post-change side -- confirm against the repository before editing.
 */
@Test
public void testDescendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
// maximize_evaluation_metrics = false
int totalIterations = 11;
int earlyStoppingRound = 10;
float[][] metrics = new float[1][totalIterations];
// Single dataset whose metric grows with the iteration index: 0, 1, 2, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
// The best iteration stays fixed at 0 for the whole loop below.
int bestIteration = 0;
for (int itr = 0; itr < totalIterations; itr++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
itr);
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration);
// Only the final iteration is expected to differ from all earlier ones.
if (itr == totalIterations - 1) {
TestCase.assertFalse(onTrack);
// Overwrite the metrics with a strictly decreasing series and re-check.
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
TestCase.assertTrue(es);
} else {
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
/**
 * Early-stopping test with three metric rows (one per dataset) and
 * maximize_evaluation_metrics = true.
 *
 * NOTE(review): this span is a rendered diff whose removed and added lines are interleaved
 * without +/- markers: two alternative "if (i >= ...)" headers open a single body near the end,
 * so the braces do not balance, and both judgeIfTrainingOnTrack and shouldEarlyStop are called.
 * It does not compile as shown; confirm which lines are current against the repository.
 */
@Test
public void testEarlyStoppingForMultipleMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
// maximize_evaluation_metrics = true
int earlyStoppingRound = 3;
int totalIterations = 5;
int numOfMetrics = 3;
float[][] metrics = new float[numOfMetrics][totalIterations];
// Only assign metric values to the first dataset; the other datasets keep their default zeros.
// The outer index i is not used in the body: row 0 is simply rewritten on every pass.
for (int i = 0; i < numOfMetrics; i++) {
for (int j = 0; j < totalIterations; j++) {
metrics[0][j] = j;
}
}
int bestIteration;
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
// The best iteration always equals the current one here, so no stop is expected.
bestIteration = i;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
TestCase.assertFalse(es);
}
// When there are multiple datasets, only the last one is used to determine early stopping.
// Here we change the metric of the first dataset, which has no effect on the final result.
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// When there are multiple datasets, the training metrics are not considered.
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
bestIteration = i;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
TestCase.assertFalse(es);
}
// Now assign decreasing metric values to the second and third datasets.
for (int i = 0; i < totalIterations; i++) {
metrics[1][i] = totalIterations - i;
metrics[2][i] = totalIterations - i;
}
// From here on the best iteration is pinned at 0 while the current iteration advances.
bestIteration = 0;
for (int i = 0; i < totalIterations; i++) {
// if any metrics off, we need to stop
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
if (i >= earlyStoppingRound - 1) {
TestCase.assertFalse(onTrack);
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
// A stop is asserted once the current iteration reaches earlyStoppingRound.
if (i >= earlyStoppingRound) {
TestCase.assertTrue(es);
} else {
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
/**
 * Early-stopping test for a metric that is minimised (maximize_evaluation_metrics = false),
 * using 10 iterations and earlyStoppingRounds = 5.
 *
 * NOTE(review): this span is a rendered diff whose removed and added lines are interleaved
 * without +/- markers, so judgeIfTrainingOnTrack and shouldEarlyStop checks appear side by
 * side. Presumably the shouldEarlyStop lines are the post-change side -- confirm against the
 * repository before editing.
 */
@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
// maximize_evaluation_metrics = false
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
// Metric grows with the iteration index: 0, 1, 2, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertFalse(onTrack);
// Best iteration 0, current iteration 9: the gap (9) exceeds earlyStoppingRounds (5),
// and a stop is asserted.
int bestIteration = 0;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
// Metric now shrinks with the iteration index: 10, 9, 8, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
// Best iteration equals the current iteration: no stop is asserted.
bestIteration = totalIterations - 1;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertFalse(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// Hand-picked values for the last five iterations.
metrics[0][5] = 1;
metrics[0][6] = 2;
metrics[0][7] = 3;
metrics[0][8] = 4;
metrics[0][9] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
metrics[0][4] = 1;
metrics[0][9] = 5;
// Best iteration 4, current iteration 9: the gap equals earlyStoppingRounds (5),
// and a stop is asserted.
bestIteration = 4;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
}
/**
 * Boundary-condition test for early stopping when the evaluation metric is maximised
 * (maximize_evaluation_metrics = true).
 *
 * NOTE(review): this span is a rendered diff whose removed and added lines are interleaved
 * without +/- markers: totalIterations is declared twice, and two differently named loops
 * ("iter" and "itr") are woven together. It does not compile as shown. Presumably the
 * shouldEarlyStop lines are the post-change side -- confirm against the repository.
 */
@Test
public void testAscendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int totalIterations = 10;
// maximize_evaluation_metrics = true
int totalIterations = 11;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int iter = 0; iter < totalIterations; iter++) {
if (iter == totalIterations - 1) {
// Increasing series: 0, 1, 2, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
// Decreasing series: totalIterations, totalIterations - 1, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertFalse(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// The best iteration stays fixed at 0 for the whole "itr" loop below.
int bestIteration = 0;
for (int itr = 0; itr < totalIterations; itr++) {
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration);
// Only the final iteration is asserted to trigger a stop.
if (itr == totalIterations - 1) {
TestCase.assertTrue(es);
} else {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
/**
 * Early-stopping test for a metric that is maximised (maximize_evaluation_metrics = true),
 * using 10 iterations and earlyStoppingRounds = 5.
 *
 * NOTE(review): this span is a rendered diff whose removed and added lines are interleaved
 * without +/- markers, so judgeIfTrainingOnTrack and shouldEarlyStop checks appear side by
 * side. Presumably the shouldEarlyStop lines are the post-change side -- confirm against the
 * repository before editing.
 */
@Test
public void testAscendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
// maximize_evaluation_metrics = true
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
// Metric grows with the iteration index: 0, 1, 2, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
// Metric now shrinks with the iteration index: 10, 9, 8, ...
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertFalse(onTrack);
// Best iteration 0, current iteration 9: the gap (9) exceeds earlyStoppingRounds (5),
// and a stop is asserted.
int bestIteration = 0;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
// Hand-picked values for the last five iterations.
metrics[0][5] = 9;
metrics[0][6] = 8;
metrics[0][7] = 7;
metrics[0][8] = 6;
metrics[0][9] = 9;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
// Best iteration equals the current iteration: no stop is asserted.
bestIteration = totalIterations - 1;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertFalse(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
metrics[0][4] = 9;
metrics[0][9] = 4;
// Best iteration 4, current iteration 9: the gap equals earlyStoppingRounds (5),
// and a stop is asserted.
bestIteration = 4;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
}
@Test
@@ -362,13 +333,13 @@ public class BoosterImplTest {
// Make sure we've stopped early.
for (int w = 0; w < watches.size(); w++) {
for (int r = 0; r < earlyStoppingRound; r++) {
for (int r = 0; r <= earlyStoppingRound; r++) {
TestCase.assertFalse(0.0f == metrics[w][r]);
}
}
for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound; r < round; r++) {
for (int r = earlyStoppingRound + 1; r < round; r++) {
TestCase.assertEquals(0.0f, metrics[w][r]);
}
}