[Blocking][jvm-packages] fix the early stopping feature (#3808)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * temp * add method for classifier and regressor * update tutorial * address the comments * update
This commit is contained in:
parent
e26b5d63b2
commit
4ae225a08d
@ -183,6 +183,15 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
|
|||||||
|
|
||||||
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
|
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
|
||||||
|
|
||||||
|
Early Stopping
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training.
|
||||||
|
|
||||||
|
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
|
||||||
|
|
||||||
|
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
|
||||||
|
|
||||||
Prediction
|
Prediction
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|||||||
@ -132,6 +132,11 @@ object XGBoost extends Serializable {
|
|||||||
try {
|
try {
|
||||||
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
||||||
.map(_.toString.toInt).getOrElse(0)
|
.map(_.toString.toInt).getOrElse(0)
|
||||||
|
if (numEarlyStoppingRounds > 0) {
|
||||||
|
if (!params.contains("maximize_evaluation_metrics")) {
|
||||||
|
throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||||
val booster = SXGBoost.train(watches.train, params, round,
|
val booster = SXGBoost.train(watches.train, params, round,
|
||||||
watches.toMap, metrics, obj, eval,
|
watches.toMap, metrics, obj, eval,
|
||||||
|
|||||||
@ -140,6 +140,9 @@ class XGBoostClassifier (
|
|||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||||
|
|
||||||
|
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
|
||||||
|
set(maximizeEvaluationMetrics, value)
|
||||||
|
|
||||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||||
|
|
||||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||||
|
|||||||
@ -140,6 +140,9 @@ class XGBoostRegressor (
|
|||||||
|
|
||||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||||
|
|
||||||
|
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
|
||||||
|
set(maximizeEvaluationMetrics, value)
|
||||||
|
|
||||||
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
|
||||||
|
|
||||||
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
|
||||||
|
|||||||
@ -87,6 +87,13 @@ private[spark] trait LearningTaskParams extends Params {
|
|||||||
|
|
||||||
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
|
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
|
||||||
|
|
||||||
|
|
||||||
|
final val maximizeEvaluationMetrics = new BooleanParam(this, "maximizeEvaluationMetrics",
|
||||||
|
"define the expected optimization to the evaluation metrics, true to maximize otherwise" +
|
||||||
|
" minimize it")
|
||||||
|
|
||||||
|
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
||||||
|
|
||||||
setDefault(objective -> "reg:linear", baseScore -> 0.5,
|
setDefault(objective -> "reg:linear", baseScore -> 0.5,
|
||||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
|
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -118,9 +118,9 @@ public class XGBoost {
|
|||||||
* performance on the validation set.
|
* performance on the validation set.
|
||||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
* iteration
|
* iteration
|
||||||
* @param earlyStoppingRound if non-zero, training would be stopped
|
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||||
* after a specified number of consecutive
|
* after a specified number of consecutive
|
||||||
* increases in any evaluation metric.
|
* goes to the unexpected direction in any evaluation metric.
|
||||||
* @param obj customized objective
|
* @param obj customized objective
|
||||||
* @param eval customized evaluation
|
* @param eval customized evaluation
|
||||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||||
@ -134,7 +134,7 @@ public class XGBoost {
|
|||||||
float[][] metrics,
|
float[][] metrics,
|
||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval,
|
IEvaluation eval,
|
||||||
int earlyStoppingRound,
|
int earlyStoppingRounds,
|
||||||
Booster booster) throws XGBoostError {
|
Booster booster) throws XGBoostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
@ -196,17 +196,14 @@ public class XGBoost {
|
|||||||
for (int i = 0; i < metricsOut.length; i++) {
|
for (int i = 0; i < metricsOut.length; i++) {
|
||||||
metrics[i][iter] = metricsOut[i];
|
metrics[i][iter] = metricsOut[i];
|
||||||
}
|
}
|
||||||
|
if (earlyStoppingRounds > 0) {
|
||||||
boolean decreasing = true;
|
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
|
||||||
float[] criterion = metrics[metrics.length - 1];
|
if (!onTrack) {
|
||||||
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
|
String reversedDirection = getReversedDirection(params);
|
||||||
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
|
Rabit.trackerPrint(String.format(
|
||||||
}
|
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
|
||||||
|
break;
|
||||||
if (!decreasing) {
|
}
|
||||||
Rabit.trackerPrint(String.format(
|
|
||||||
"early stopping after %d decreasing rounds", earlyStoppingRound));
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
if (Rabit.getRank() == 0) {
|
if (Rabit.getRank() == 0) {
|
||||||
Rabit.trackerPrint(evalInfo + '\n');
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
@ -217,6 +214,41 @@ public class XGBoost {
|
|||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static boolean judgeIfTrainingOnTrack(
|
||||||
|
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
|
||||||
|
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
|
||||||
|
boolean onTrack = false;
|
||||||
|
float[] criterion = metrics[metrics.length - 1];
|
||||||
|
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
|
||||||
|
onTrack |= maximizeEvaluationMetrics ?
|
||||||
|
criterion[iter - shift] >= criterion[iter - shift - 1] :
|
||||||
|
criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||||
|
}
|
||||||
|
return onTrack;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getReversedDirection(Map<String, Object> params) {
|
||||||
|
String reversedDirection = null;
|
||||||
|
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||||
|
reversedDirection = "descending";
|
||||||
|
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
|
||||||
|
reversedDirection = "ascending";
|
||||||
|
}
|
||||||
|
return reversedDirection;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
|
||||||
|
try {
|
||||||
|
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||||
|
assert(maximize != null);
|
||||||
|
return Boolean.valueOf(maximize);
|
||||||
|
} catch (Exception ex) {
|
||||||
|
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
||||||
|
" allowed value: true/false", ex);
|
||||||
|
throw ex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cross-validation with given parameters.
|
* Cross-validation with given parameters.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -152,6 +152,66 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDescendMetrics() {
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "false");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
float[][] metrics = new float[1][5];
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertFalse(onTrack);
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = 5 - i;
|
||||||
|
}
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = 5 - i;
|
||||||
|
}
|
||||||
|
metrics[0][0] = 1;
|
||||||
|
metrics[0][2] = 5;
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAscendMetrics() {
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "true");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
float[][] metrics = new float[1][5];
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = 5 - i;
|
||||||
|
}
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertFalse(onTrack);
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
metrics[0][i] = i;
|
||||||
|
}
|
||||||
|
metrics[0][0] = 6;
|
||||||
|
metrics[0][2] = 1;
|
||||||
|
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
|
||||||
|
TestCase.assertTrue(onTrack);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
@ -162,6 +222,7 @@ public class BoosterImplTest {
|
|||||||
put("max_depth", 3);
|
put("max_depth", 3);
|
||||||
put("silent", 1);
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
|
put("maximize_evaluation_metrics", "false");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user