[jvm-packages] Fix early stop with xgboost4j-spark (#4176)
* Fix early stop with xgboost4j-spark * Update XGBoost.java * Update XGBoost.java * Update XGBoost.java To use -Float.MAX_VALUE as the lower bound, in case there is positive metric. * Only update best score if the current score is better (no update when equal) * Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
parent
7ea5675679
commit
9fefa2128d
@ -194,11 +194,11 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
|
||||
Early Stopping
|
||||
----------------
|
||||
|
||||
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training.
|
||||
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds if the evaluation metric going away from the best iteration and early stop training iterations.
|
||||
|
||||
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
|
||||
|
||||
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
|
||||
For example, we need to maximize the evaluation metrics (set ``maximize_evaluation_metrics`` with true), and set ``num_early_stopping_rounds`` with 5. The evaluation metric of 10th iteration is the maximum one until now. In the following iterations, if there is no evaluation metric greater than the 10th iteration's (best one), the traning would be early stopped at 15th iteration.
|
||||
|
||||
Training with Evaluation Sets
|
||||
----------------
|
||||
|
||||
@ -140,6 +140,8 @@ public class XGBoost {
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
float bestScore;
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
|
||||
@ -150,6 +152,12 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||
|
||||
//collect all data matrixs
|
||||
@ -196,12 +204,27 @@ public class XGBoost {
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
|
||||
// If there is more than one evaluation datasets, the last one would be used
|
||||
// to determinate early stop.
|
||||
float score = metricsOut[metricsOut.length - 1];
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
// Update best score if the current score is better (no update when equal)
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
}
|
||||
} else {
|
||||
if (score < bestScore) {
|
||||
bestScore = score;
|
||||
bestIteration = iter;
|
||||
}
|
||||
}
|
||||
if (earlyStoppingRounds > 0) {
|
||||
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
|
||||
if (!onTrack) {
|
||||
String reversedDirection = getReversedDirection(params);
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -214,42 +237,11 @@ public class XGBoost {
|
||||
return booster;
|
||||
}
|
||||
|
||||
static boolean judgeIfTrainingOnTrack(
|
||||
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
||||
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
||||
boolean onTrack = false;
|
||||
// we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration
|
||||
if (iter < earlyStoppingRounds - 1) {
|
||||
return true;
|
||||
}
|
||||
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
|
||||
float[] criterion = metrics[metricsId];
|
||||
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
|
||||
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
|
||||
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
|
||||
// as `onTrack`
|
||||
onTrack |= maximizeEvaluationMetrics ?
|
||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
if (!onTrack) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return onTrack;
|
||||
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
|
||||
return iter - bestIteration >= earlyStoppingRounds;
|
||||
}
|
||||
|
||||
private static String getReversedDirection(Map<String, Object> params) {
|
||||
String reversedDirection = null;
|
||||
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "descending";
|
||||
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "ascending";
|
||||
}
|
||||
return reversedDirection;
|
||||
}
|
||||
|
||||
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
|
||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
|
||||
@ -154,188 +154,159 @@ public class BoosterImplTest {
|
||||
|
||||
@Test
|
||||
public void testDescendMetricsWithBoundaryCondition() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 10;
|
||||
// maximize_evaluation_metrics = false
|
||||
int totalIterations = 11;
|
||||
int earlyStoppingRound = 10;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
int bestIteration = 0;
|
||||
|
||||
for (int itr = 0; itr < totalIterations; itr++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
itr);
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, itr, bestIteration);
|
||||
if (itr == totalIterations - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStoppingForMultipleMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = true
|
||||
int earlyStoppingRound = 3;
|
||||
int totalIterations = 5;
|
||||
int numOfMetrics = 3;
|
||||
float[][] metrics = new float[numOfMetrics][totalIterations];
|
||||
// Only assign metric values to the first dataset, zeros for other datasets
|
||||
for (int i = 0; i < numOfMetrics; i++) {
|
||||
for (int j = 0; j < totalIterations; j++) {
|
||||
metrics[0][j] = j;
|
||||
}
|
||||
}
|
||||
int bestIteration;
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = i;
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
|
||||
// when we have multiple datasets, only the last one was used to determinate early stop
|
||||
// Here we changed the metric of the first dataset, it doesn't have any effect to the final result
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
// when we have multiple datasets, the training metrics is not considered
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = i;
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
|
||||
// Now assign metric values to the last dataset.
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[1][i] = totalIterations - i;
|
||||
metrics[2][i] = totalIterations - i;
|
||||
}
|
||||
bestIteration = 0;
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
// if any metrics off, we need to stop
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRound, metrics, i);
|
||||
if (i >= earlyStoppingRound - 1) {
|
||||
TestCase.assertFalse(onTrack);
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRound, i, bestIteration);
|
||||
if (i >= earlyStoppingRound) {
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDescendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "false");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = false
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 5;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertFalse(onTrack);
|
||||
int bestIteration = 0;
|
||||
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = totalIterations - 1;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
metrics[0][5] = 1;
|
||||
metrics[0][6] = 2;
|
||||
metrics[0][7] = 3;
|
||||
metrics[0][8] = 4;
|
||||
metrics[0][9] = 1;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics,
|
||||
totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
metrics[0][4] = 1;
|
||||
metrics[0][9] = 5;
|
||||
|
||||
bestIteration = 4;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAscendMetricsWithBoundaryCondition() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
int totalIterations = 10;
|
||||
// maximize_evaluation_metrics = true
|
||||
int totalIterations = 11;
|
||||
int earlyStoppingRounds = 10;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int iter = 0; iter < totalIterations; iter++) {
|
||||
if (iter == totalIterations - 1) {
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertFalse(onTrack);
|
||||
int bestIteration = 0;
|
||||
|
||||
for (int itr = 0; itr < totalIterations; itr++) {
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, itr, bestIteration);
|
||||
if (itr == totalIterations - 1) {
|
||||
TestCase.assertTrue(es);
|
||||
} else {
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, iter);
|
||||
TestCase.assertTrue(onTrack);
|
||||
TestCase.assertFalse(es);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAscendMetrics() {
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("maximize_evaluation_metrics", "true");
|
||||
}
|
||||
};
|
||||
// maximize_evaluation_metrics = true
|
||||
int totalIterations = 10;
|
||||
int earlyStoppingRounds = 5;
|
||||
float[][] metrics = new float[1][totalIterations];
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = totalIterations - i;
|
||||
}
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertFalse(onTrack);
|
||||
int bestIteration = 0;
|
||||
|
||||
boolean es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
metrics[0][5] = 9;
|
||||
metrics[0][6] = 8;
|
||||
metrics[0][7] = 7;
|
||||
metrics[0][8] = 6;
|
||||
metrics[0][9] = 9;
|
||||
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, earlyStoppingRounds, metrics, totalIterations - 1);
|
||||
TestCase.assertTrue(onTrack);
|
||||
bestIteration = totalIterations - 1;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertFalse(es);
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
metrics[0][i] = i;
|
||||
}
|
||||
metrics[0][4] = 9;
|
||||
metrics[0][9] = 4;
|
||||
|
||||
bestIteration = 4;
|
||||
|
||||
es = XGBoost.shouldEarlyStop(earlyStoppingRounds, totalIterations - 1, bestIteration);
|
||||
TestCase.assertTrue(es);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -362,13 +333,13 @@ public class BoosterImplTest {
|
||||
|
||||
// Make sure we've stopped early.
|
||||
for (int w = 0; w < watches.size(); w++) {
|
||||
for (int r = 0; r < earlyStoppingRound; r++) {
|
||||
for (int r = 0; r <= earlyStoppingRound; r++) {
|
||||
TestCase.assertFalse(0.0f == metrics[w][r]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int w = 0; w < watches.size(); w++) {
|
||||
for (int r = earlyStoppingRound; r < round; r++) {
|
||||
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
||||
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user