[Blocking][jvm-packages] fix the early stopping feature (#3808)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* temp

* add method for classifier and regressor

* update tutorial

* address the comments

* update
This commit is contained in:
Nan Zhu 2018-10-23 14:53:13 -07:00 committed by GitHub
parent e26b5d63b2
commit 4ae225a08d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 134 additions and 14 deletions

View File

@ -183,6 +183,15 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
val xgbClassificationModel = xgbClassifier.fit(xgbInput) val xgbClassificationModel = xgbClassifier.fit(xgbInput)
Early Stopping
----------------
Early stopping is a feature that prevents unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly calling ``setNumEarlyStoppingRounds`` on an XGBoostClassifier or XGBoostRegressor, we can define the number of rounds that the evaluation metric is allowed to move in the unexpected direction before training is stopped.
In addition to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
After specifying these two parameters, training stops when the metric moves in the direction opposite to the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
Prediction Prediction
========== ==========

View File

@ -132,6 +132,11 @@ object XGBoost extends Serializable {
try { try {
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0) .map(_.toString.toInt).getOrElse(0)
if (numEarlyStoppingRounds > 0) {
if (!params.contains("maximize_evaluation_metrics")) {
throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
}
}
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round, val booster = SXGBoost.train(watches.train, params, round,
watches.toMap, metrics, obj, eval, watches.toMap, metrics, obj, eval,

View File

@ -140,6 +140,9 @@ class XGBoostClassifier (
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
/**
 * Stores `value` in the `maximizeEvaluationMetrics` param.
 *
 * @param value true to maximize the evaluation metrics, false to minimize them
 */
def setMaximizeEvaluationMetrics(value: Boolean): this.type = {
  set(maximizeEvaluationMetrics, value)
}
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
def setCustomEval(value: EvalTrait): this.type = set(customEval, value) def setCustomEval(value: EvalTrait): this.type = set(customEval, value)

View File

@ -140,6 +140,9 @@ class XGBoostRegressor (
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
/**
 * Stores `value` in the `maximizeEvaluationMetrics` param.
 *
 * @param value true to maximize the evaluation metrics, false to minimize them
 */
def setMaximizeEvaluationMetrics(value: Boolean): this.type = {
  set(maximizeEvaluationMetrics, value)
}
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
def setCustomEval(value: EvalTrait): this.type = set(customEval, value) def setCustomEval(value: EvalTrait): this.type = set(customEval, value)

View File

@ -87,6 +87,13 @@ private[spark] trait LearningTaskParams extends Params {
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds) final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
/**
 * Boolean param giving the direction in which the evaluation metrics are expected to move:
 * true to maximize, false to minimize.
 */
final val maximizeEvaluationMetrics = new BooleanParam(
  this,
  "maximizeEvaluationMetrics",
  "define the expected optimization to the evaluation metrics, true to maximize otherwise" +
    " minimize it")

/** Reads the current value of the `maximizeEvaluationMetrics` param. */
final def getMaximizeEvaluationMetrics: Boolean = {
  $(maximizeEvaluationMetrics)
}
setDefault(objective -> "reg:linear", baseScore -> 0.5, setDefault(objective -> "reg:linear", baseScore -> 0.5,
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0) trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
} }

View File

@ -118,9 +118,9 @@ public class XGBoost {
* performance on the validation set. * performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each * @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration * iteration
* @param earlyStoppingRound if non-zero, training would be stopped * @param earlyStoppingRounds if non-zero, training would be stopped
* after a specified number of consecutive * after a specified number of consecutive
* increases in any evaluation metric. * goes to the unexpected direction in any evaluation metric.
* @param obj customized objective * @param obj customized objective
* @param eval customized evaluation * @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null. * @param booster train from scratch if set to null; train from an existing booster if not null.
@ -134,7 +134,7 @@ public class XGBoost {
float[][] metrics, float[][] metrics,
IObjective obj, IObjective obj,
IEvaluation eval, IEvaluation eval,
int earlyStoppingRound, int earlyStoppingRounds,
Booster booster) throws XGBoostError { Booster booster) throws XGBoostError {
//collect eval matrixs //collect eval matrixs
@ -196,18 +196,15 @@ public class XGBoost {
for (int i = 0; i < metricsOut.length; i++) { for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i]; metrics[i][iter] = metricsOut[i];
} }
if (earlyStoppingRounds > 0) {
boolean decreasing = true; boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
float[] criterion = metrics[metrics.length - 1]; if (!onTrack) {
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) { String reversedDirection = getReversedDirection(params);
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!decreasing) {
Rabit.trackerPrint(String.format( Rabit.trackerPrint(String.format(
"early stopping after %d decreasing rounds", earlyStoppingRound)); "early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
break; break;
} }
}
if (Rabit.getRank() == 0) { if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n'); Rabit.trackerPrint(evalInfo + '\n');
} }
@ -217,6 +214,41 @@ public class XGBoost {
return booster; return booster;
} }
/**
 * Decides whether training should carry on when early stopping is enabled.
 *
 * Only the history of the last row of {@code metrics} is inspected. The run counts as on
 * track as soon as one step inside the look-back window moved in the direction requested by
 * the "maximize_evaluation_metrics" parameter (ties count as moving the right way).
 *
 * @param params training parameters; "maximize_evaluation_metrics" selects the direction
 * @param earlyStoppingRounds size of the look-back window, in rounds
 * @param metrics metric history; the last row, indexed by iteration, is the criterion
 * @param iter index of the iteration whose metric was recorded most recently
 * @return true if training should continue, false if it should stop early
 */
static boolean judgeIfTrainingOnTrack(
    Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
  boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
  int comparisons = Math.min(iter, earlyStoppingRounds) - 1;
  if (comparisons <= 0) {
    // Not enough history to compare anything (iter < 2 or earlyStoppingRounds < 2).
    // Without this guard the method would answer "off track" and the caller would abort
    // training in its very first iterations, whatever the metric values are.
    return true;
  }
  float[] criterion = metrics[metrics.length - 1];
  for (int shift = 0; shift < comparisons; shift++) {
    float current = criterion[iter - shift];
    float previous = criterion[iter - shift - 1];
    boolean movedAsExpected = maximizeEvaluationMetrics
        ? current >= previous
        : current <= previous;
    if (movedAsExpected) {
      // One good step inside the window is enough to keep going.
      return true;
    }
  }
  return false;
}
/**
 * Names the direction opposite to the one requested through "maximize_evaluation_metrics".
 *
 * @param params training parameters
 * @return "descending" when the parameter parses to true, "ascending" in every other case
 *         (including when the parameter is absent)
 */
private static String getReversedDirection(Map<String, Object> params) {
  boolean maximize = Boolean.parseBoolean(
      String.valueOf(params.get("maximize_evaluation_metrics")));
  return maximize ? "descending" : "ascending";
}
/**
 * Reads the expected direction of the evaluation metric from the training parameters.
 *
 * @param params training parameters; must contain "maximize_evaluation_metrics"
 * @return true if the metric is to be maximized, false if it is to be minimized
 * @throws IllegalArgumentException if "maximize_evaluation_metrics" is absent or null
 */
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
  // Check the raw value: String.valueOf(Object) turns null into the string "null", so a
  // null check made after the conversion can never fail and a missing parameter would
  // silently be read as "minimize".
  Object maximize = params == null ? null : params.get("maximize_evaluation_metrics");
  if (maximize == null) {
    IllegalArgumentException ex = new IllegalArgumentException(
        "maximize_evaluation_metrics has to be specified for enabling early stop," +
        " allowed value: true/false");
    logger.error(ex.getMessage(), ex);
    throw ex;
  }
  // Case-insensitive "true" means maximize; any other value means minimize.
  return Boolean.parseBoolean(String.valueOf(maximize));
}
/** /**
* Cross-validation with given parameters. * Cross-validation with given parameters.
* *

View File

@ -152,6 +152,66 @@ public class BoosterImplTest {
} }
} }
/**
 * With minimization requested, a steadily rising metric must be reported as off track,
 * while a falling metric, or one that falls at least once in the window, is on track.
 */
@Test
public void testDescendMetrics() {
  Map<String, Object> paramMap = new HashMap<String, Object>();
  paramMap.put("max_depth", 3);
  paramMap.put("silent", 1);
  paramMap.put("objective", "binary:logistic");
  paramMap.put("maximize_evaluation_metrics", "false");

  // Metric rises on every iteration: the wrong direction for minimization.
  float[][] rising = {{0f, 1f, 2f, 3f, 4f}};
  TestCase.assertFalse(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, rising, 4));

  // Metric falls on every iteration: the expected direction.
  float[][] falling = {{5f, 4f, 3f, 2f, 1f}};
  TestCase.assertTrue(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, falling, 4));

  // Metric rises first and then falls again towards the end of the window.
  float[][] mixed = {{1f, 4f, 5f, 2f, 1f}};
  TestCase.assertTrue(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, mixed, 4));
}
/**
 * With maximization requested, a steadily falling metric must be reported as off track,
 * while a rising metric, or one that rises at least once in the window, is on track.
 */
@Test
public void testAscendMetrics() {
  Map<String, Object> paramMap = new HashMap<String, Object>();
  paramMap.put("max_depth", 3);
  paramMap.put("silent", 1);
  paramMap.put("objective", "binary:logistic");
  paramMap.put("maximize_evaluation_metrics", "true");

  // Metric rises on every iteration: the expected direction for maximization.
  float[][] rising = {{0f, 1f, 2f, 3f, 4f}};
  TestCase.assertTrue(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, rising, 4));

  // Metric falls on every iteration: the wrong direction.
  float[][] falling = {{5f, 4f, 3f, 2f, 1f}};
  TestCase.assertFalse(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, falling, 4));

  // Metric drops first and then climbs again towards the end of the window.
  float[][] mixed = {{6f, 1f, 1f, 3f, 4f}};
  TestCase.assertTrue(XGBoost.judgeIfTrainingOnTrack(paramMap, 5, mixed, 4));
}
@Test @Test
public void testBoosterEarlyStop() throws XGBoostError, IOException { public void testBoosterEarlyStop() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
@ -162,6 +222,7 @@ public class BoosterImplTest {
put("max_depth", 3); put("max_depth", 3);
put("silent", 1); put("silent", 1);
put("objective", "binary:logistic"); put("objective", "binary:logistic");
put("maximize_evaluation_metrics", "false");
} }
}; };
Map<String, DMatrix> watches = new LinkedHashMap<>(); Map<String, DMatrix> watches = new LinkedHashMap<>();