[jvm-packages] Fix early stop with xgboost4j-spark (#4176)

* Fix early stop with xgboost4j-spark

* Update XGBoost.java

* Update XGBoost.java

* Update XGBoost.java

Use -Float.MAX_VALUE as the lower bound, in case there is a positive metric.

* Only update best score if the current score is better (no update when equal)

* Update xgboost-spark tutorial to fix early stopping docs.
This commit is contained in:
Yanbo Liang
2019-03-01 13:02:57 -08:00
committed by Nan Zhu
parent 7ea5675679
commit 9fefa2128d
3 changed files with 113 additions and 150 deletions

View File

@@ -140,6 +140,8 @@ public class XGBoost {
// collect eval matrices
String[] evalNames;
DMatrix[] evalMats;
float bestScore;
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
@@ -150,6 +152,12 @@ public class XGBoost {
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
if (isMaximizeEvaluation(params)) {
bestScore = -Float.MAX_VALUE;
} else {
bestScore = Float.MAX_VALUE;
}
bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
// collect all data matrices
@@ -196,12 +204,27 @@ public class XGBoost {
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}
// If there is more than one evaluation dataset, the last one is used
// to determine early stopping.
float score = metricsOut[metricsOut.length - 1];
if (isMaximizeEvaluation(params)) {
// Update best score if the current score is better (no update when equal)
if (score > bestScore) {
bestScore = score;
bestIteration = iter;
}
} else {
if (score < bestScore) {
bestScore = score;
bestIteration = iter;
}
}
if (earlyStoppingRounds > 0) {
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
if (!onTrack) {
String reversedDirection = getReversedDirection(params);
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
Rabit.trackerPrint(String.format(
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
"early stopping after %d rounds away from the best iteration",
earlyStoppingRounds));
break;
}
}
@@ -214,42 +237,11 @@ public class XGBoost {
return booster;
}
static boolean judgeIfTrainingOnTrack(
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
boolean onTrack = false;
// we don't need to consider iterations before reaching to `earlyStoppingRounds`th iteration
if (iter < earlyStoppingRounds - 1) {
return true;
}
for (int metricsId = metrics.length == 1 ? 0 : 1; metricsId < metrics.length; metricsId++) {
float[] criterion = metrics[metricsId];
for (int shift = 0; shift < earlyStoppingRounds - 1; shift++) {
// the initial value of onTrack is false and if the metrics in any of `earlyStoppingRounds`
// iterations goes to the expected direction, we should consider these `earlyStoppingRounds`
// as `onTrack`
onTrack |= maximizeEvaluationMetrics ?
criterion[iter - shift] >= criterion[iter - shift - 1] :
criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!onTrack) {
return false;
}
}
return onTrack;
/**
 * Decides whether training should stop early.
 *
 * @param earlyStoppingRounds threshold compared against the distance from the best iteration
 * @param iter index of the current iteration
 * @param bestIteration index of the iteration recorded as the best so far
 * @return {@code true} when the current iteration is at least {@code earlyStoppingRounds}
 *         iterations past {@code bestIteration}, {@code false} otherwise
 */
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
  final int roundsSinceBest = iter - bestIteration;
  return roundsSinceBest >= earlyStoppingRounds;
}
/**
 * Names the direction opposite to the one the evaluation metric is expected to move in.
 *
 * @param params training parameters; the {@code "maximize_evaluation_metrics"} entry is read
 * @return {@code "descending"} when the string form of that entry parses as {@code true},
 *         otherwise {@code "ascending"} (including when the entry is absent)
 */
private static String getReversedDirection(Map<String, Object> params) {
  // String.valueOf turns a missing (null) entry into "null", which parses as false.
  final boolean maximize =
      Boolean.parseBoolean(String.valueOf(params.get("maximize_evaluation_metrics")));
  return maximize ? "descending" : "ascending";
}
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
try {
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
assert(maximize != null);