[Blocking][jvm-packages] fix the early stopping feature (#3808)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * temp * add method for classifier and regressor * update tutorial * address the comments * update
This commit is contained in:
@@ -118,9 +118,9 @@ public class XGBoost {
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* increases in any evaluation metric.
|
||||
* goes to the unexpected direction in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
@@ -134,7 +134,7 @@ public class XGBoost {
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRound,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
@@ -196,17 +196,14 @@ public class XGBoost {
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
|
||||
boolean decreasing = true;
|
||||
float[] criterion = metrics[metrics.length - 1];
|
||||
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
|
||||
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
|
||||
if (!decreasing) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d decreasing rounds", earlyStoppingRound));
|
||||
break;
|
||||
if (earlyStoppingRounds > 0) {
|
||||
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
|
||||
if (!onTrack) {
|
||||
String reversedDirection = getReversedDirection(params);
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (Rabit.getRank() == 0) {
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
@@ -217,6 +214,41 @@ public class XGBoost {
|
||||
return booster;
|
||||
}
|
||||
|
||||
static boolean judgeIfTrainingOnTrack(
|
||||
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
||||
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
||||
boolean onTrack = false;
|
||||
float[] criterion = metrics[metrics.length - 1];
|
||||
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
|
||||
onTrack |= maximizeEvaluationMetrics ?
|
||||
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
return onTrack;
|
||||
}
|
||||
|
||||
private static String getReversedDirection(Map<String, Object> params) {
|
||||
String reversedDirection = null;
|
||||
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "descending";
|
||||
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||
reversedDirection = "ascending";
|
||||
}
|
||||
return reversedDirection;
|
||||
}
|
||||
|
||||
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
return Boolean.valueOf(maximize);
|
||||
} catch (Exception ex) {
|
||||
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
||||
" allowed value: true/false", ex);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user