[jvm-packages] Fix early stop with xgboost4j-spark (#4176)

* Fix early stop with xgboost4j-spark

* Update XGBoost.java

* Update XGBoost.java

* Update XGBoost.java

To use -Float.MAX_VALUE as the lower bound, in case there is positive metric.

* Only update best score if the current score is better (no update when equal)

* Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
Yanbo Liang 2019-03-01 13:02:57 -08:00 committed by Nan Zhu
parent 7ea5675679
commit 9fefa2128d
3 changed files with 113 additions and 150 deletions

View File

@ -194,11 +194,11 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
Early Stopping
----------------
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training.
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds if the evaluation metric going away from the best iteration and early stop training iterations.
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
For example, we need to maximize the evaluation metrics (set ``maximize_evaluation_metrics`` with true), and set ``num_early_stopping_rounds`` with 5. The evaluation metric of 10th iteration is the maximum one until now. In the following iterations, if there is no evaluation metric greater than the 10th iteration's (best one), the traning would be early stopped at 15th iteration.
Training with Evaluation Sets
----------------

View File

@ -140,6 +140,8 @@ public class XGBoost {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
float bestScore;
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
@ -150,6 +152,12 @@ public class XGBoost {
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
if (isMaximizeEvaluation(params)) {
bestScore = -Float.MAX_VALUE;
} else {
bestScore = Float.MAX_VALUE;
}
bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
//collect all data matrixs
@ -196,12 +204,27 @@ public class XGBoost {
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}
// If there is more than one evaluation datasets, the last one would be used
// to determinate early stop.
float score = metricsOut[metricsOut.length - 1];
if (isMaximizeEvaluation(params)) {
// Update best score if the current score is better (no update when equal)
if (score > bestScore) {
bestScore = score;
bestIteration = iter;
}
} else {
if (score < bestScore) {
bestScore = score;
bestIteration = iter;
}
}
if (earlyStoppingRounds > 0) {
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
if (!onTrack) {
String reversedDirection = getReversedDirection(params);
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
Rabit.trackerPrint(String.format(
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
"early stopping after %d rounds away from the best iteration",
earlyStoppingRounds));
break;
}
}
@ -214,42 +237,11 @@ public class XGBoost {
return booster;
}
static boolean judgeIfTrainingOnTrack(
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
boolean onTrack = false;
// we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration
if (iter < earlyStoppingRounds - 1) {
return true;
}
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
float[] criterion = metrics[metricsId];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!onTrack) {
return false;
}
}
return onTrack;
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
return iter - bestIteration >= earlyStoppingRounds;
}
private static String getReversedDirection(Map<String, Object> params) {
String reversedDirection = null;
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
reversedDirection = "descending";
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
reversedDirection = "ascending";
}
return reversedDirection;
}
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
try {
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
assert(maximize != null);

View File

@ -154,188 +154,159 @@ public class BoosterImplTest {
@Test
public void testDescendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
int totalIterations = 10;
int earlyStoppingRounds = 10;
// maximize_evaluation_metrics = false
int totalIterations = 11;
int earlyStoppingRound = 10;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
int bestIteration = 0;
for (int itr = 0; itr < totalIterations; itr++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
itr);
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration);
if (itr == totalIterations - 1) {
TestCase.assertFalse(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
TestCase.assertTrue(es);
} else {
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
@Test
public void testEarlyStoppingForMultipleMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
// maximize_evaluation_metrics = true
int earlyStoppingRound = 3;
int totalIterations = 5;
int numOfMetrics = 3;
float[][] metrics = new float[numOfMetrics][totalIterations];
// Only assign metric values to the first dataset, zeros for other datasets
for (int i = 0; i < numOfMetrics; i++) {
for (int j = 0; j < totalIterations; j++) {
metrics[0][j] = j;
}
}
int bestIteration;
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
bestIteration = i;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
TestCase.assertFalse(es);
}
// when we have multiple datasets, only the last one was used to determinate early stop
// Here we changed the metric of the first dataset, it doesn't have any effect to the final result
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
// when we have multiple datasets, the training metrics is not considered
for (int i = 0; i < totalIterations; i++) {
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
TestCase.assertTrue(onTrack);
bestIteration = i;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
TestCase.assertFalse(es);
}
// Now assign metric values to the last dataset.
for (int i = 0; i < totalIterations; i++) {
metrics[1][i] = totalIterations - i;
metrics[2][i] = totalIterations - i;
}
bestIteration = 0;
for (int i = 0; i < totalIterations; i++) {
// if any metrics off, we need to stop
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
if (i >= earlyStoppingRound - 1) {
TestCase.assertFalse(onTrack);
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
if (i >= earlyStoppingRound) {
TestCase.assertTrue(es);
} else {
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
@Test
public void testDescendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
}
};
// maximize_evaluation_metrics = false
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertFalse(onTrack);
int bestIteration = 0;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
bestIteration = totalIterations - 1;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertFalse(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
metrics[0][5] = 1;
metrics[0][6] = 2;
metrics[0][7] = 3;
metrics[0][8] = 4;
metrics[0][9] = 1;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
totalIterations - 1);
TestCase.assertTrue(onTrack);
metrics[0][4] = 1;
metrics[0][9] = 5;
bestIteration = 4;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
}
@Test
public void testAscendMetricsWithBoundaryCondition() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
int totalIterations = 10;
// maximize_evaluation_metrics = true
int totalIterations = 11;
int earlyStoppingRounds = 10;
float[][] metrics = new float[1][totalIterations];
for (int iter = 0; iter < totalIterations; iter++) {
if (iter == totalIterations - 1) {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertFalse(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
int bestIteration = 0;
for (int itr = 0; itr < totalIterations; itr++) {
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration);
if (itr == totalIterations - 1) {
TestCase.assertTrue(es);
} else {
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
TestCase.assertTrue(onTrack);
TestCase.assertFalse(es);
}
}
}
@Test
public void testAscendMetrics() {
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "true");
}
};
// maximize_evaluation_metrics = true
int totalIterations = 10;
int earlyStoppingRounds = 5;
float[][] metrics = new float[1][totalIterations];
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = totalIterations - i;
}
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertFalse(onTrack);
int bestIteration = 0;
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
metrics[0][5] = 9;
metrics[0][6] = 8;
metrics[0][7] = 7;
metrics[0][8] = 6;
metrics[0][9] = 9;
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
TestCase.assertTrue(onTrack);
bestIteration = totalIterations - 1;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertFalse(es);
for (int i = 0; i < totalIterations; i++) {
metrics[0][i] = i;
}
metrics[0][4] = 9;
metrics[0][9] = 4;
bestIteration = 4;
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
TestCase.assertTrue(es);
}
@Test
@ -362,13 +333,13 @@ public class BoosterImplTest {
// Make sure we've stopped early.
for (int w = 0; w < watches.size(); w++) {
for (int r = 0; r < earlyStoppingRound; r++) {
for (int r = 0; r <= earlyStoppingRound; r++) {
TestCase.assertFalse(0.0f == metrics[w][r]);
}
}
for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound; r < round; r++) {
for (int r = earlyStoppingRound + 1; r < round; r++) {
TestCase.assertEquals(0.0f, metrics[w][r]);
}
}