[jvm-packages] Implemented early stopping (#2710)

* Allowed subsampling test from the training data frame/RDD

The implementation requires holding a fraction (1 - trainTestRatio) of the
data points in memory for the sampling to work.

An alternative approach would be to construct the full DMatrix and then
slice it deterministically into train/test. The peak memory consumption
of such a scenario, however, is twice the dataset size.

* Removed duplication from 'XGBoost.train'

Scala callers can (and should) use names to supply a subset of
parameters. Method overloading is not required.

* Reuse XGBoost seed parameter to stabilize train/test splitting

* Added early stopping support to non-distributed XGBoost

Closes #1544

* Added early-stopping to distributed XGBoost

* Moved construction of 'watches' into a separate method

This commit also fixes the handling of 'baseMargin' which previously
was not added to the validation matrix.

* Addressed review comments
This commit is contained in:
Sergei Lebedev
2017-09-29 21:06:22 +02:00
committed by Nan Zhu
parent 74db9757b3
commit 69c3b78a29
15 changed files with 191 additions and 91 deletions

View File

@@ -201,6 +201,12 @@ public class Booster implements Serializable, KryoSerializable {
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
    throws XGBoostError {
  // Delegate to the overload that also reports per-matrix metric values.
  // The scratch buffer is discarded; it is small, so the extra allocation
  // is negligible.
  float[] ignoredMetrics = new float[evalNames.length];
  return evalSet(evalMatrixs, evalNames, eval, ignoredMetrics);
}
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval,
float[] metricsOut) throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
@@ -208,6 +214,7 @@ public class Booster implements Serializable, KryoSerializable {
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
metricsOut[i] = evalResult;
}
return evalInfo;
}

View File

@@ -64,7 +64,7 @@ public class XGBoost {
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
return train(dtrain, params, round, watches, null, obj, eval);
return train(dtrain, params, round, watches, null, obj, eval, 0);
}
public static Booster train(
@@ -74,7 +74,8 @@ public class XGBoost {
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval) throws XGBoostError {
IEvaluation eval,
int earlyStoppingRound) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@@ -89,6 +90,7 @@ public class XGBoost {
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
//collect all data matrixs
DMatrix[] allMats;
@@ -120,19 +122,27 @@ public class XGBoost {
//evaluation
if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length];
String evalInfo;
if (eval != null) {
evalInfo = booster.evalSet(evalMats, evalNames, eval);
evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut);
} else {
if (metrics == null) {
evalInfo = booster.evalSet(evalMats, evalNames, iter);
} else {
float[] m = new float[evalMats.length];
evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
for (int i = 0; i < m.length; i++) {
metrics[i][iter] = m[i];
}
}
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
}
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}
boolean decreasing = true;
float[] criterion = metrics[metrics.length - 1];
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
}
if (!decreasing) {
Rabit.trackerPrint(String.format(
"early stopping after %d decreasing rounds", earlyStoppingRound));
break;
}
if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');

View File

@@ -36,6 +36,9 @@ object XGBoost {
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training stops after the
*                           specified number of consecutive increases
*                           in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
@@ -45,44 +48,20 @@ object XGBoost {
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix],
metrics: Array[Array[Float]],
obj: ObjectiveTrait,
eval: EvalTrait): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
watches: Map[String, DMatrix] = Map(),
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
earlyStoppingRound: Int = 0): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
round, jWatches.asJava, metrics, obj, eval)
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound)
new Booster(xgboostInJava)
}
/**
 * Trains a booster with the given parameters.
 *
 * @param dtrain  data to be trained.
 * @param params  booster parameters.
 * @param round   number of boosting iterations.
 * @param watches a group of items to be evaluated during training, this allows user to watch
 *                performance on the validation set.
 * @param obj     customized objective.
 * @param eval    customized evaluation.
 * @return The trained booster.
 */
@throws(classOf[XGBoostError])
def train(
    dtrain: DMatrix,
    params: Map[String, Any],
    round: Int,
    watches: Map[String, DMatrix] = Map[String, DMatrix](),
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null): Booster =
  // Delegate to the full overload; no per-iteration metrics buffer is supplied.
  train(dtrain, params, round, watches, null, obj, eval)
/**
* Cross-validation with given parameters.
*