[jvm-packages] Implemented early stopping (#2710)
* Allowed subsampling the test set from the training data frame/RDD. The implementation requires storing a (1 - trainTestRatio) fraction of the points in memory to make the sampling work. An alternative approach would be to construct the full DMatrix and then slice it deterministically into train/test. The peak memory consumption of such a scenario, however, is twice the dataset size.
* Removed duplication from 'XGBoost.train'. Scala callers can (and should) use named arguments to supply a subset of parameters. Method overloading is not required.
* Reused the XGBoost seed parameter to stabilize train/test splitting.
* Added early stopping support to non-distributed XGBoost. Closes #1544.
* Added early stopping to distributed XGBoost.
* Moved construction of 'watches' into a separate method. This commit also fixes the handling of 'baseMargin', which previously was not added to the validation matrix.
* Addressed review comments.
This commit is contained in:
@@ -201,6 +201,12 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
||||
throws XGBoostError {
|
||||
// Hopefully, a tiny redundant allocation wouldn't hurt.
|
||||
return evalSet(evalMatrixs, evalNames, eval, new float[evalNames.length]);
|
||||
}
|
||||
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval,
|
||||
float[] metricsOut) throws XGBoostError {
|
||||
String evalInfo = "";
|
||||
for (int i = 0; i < evalNames.length; i++) {
|
||||
String evalName = evalNames[i];
|
||||
@@ -208,6 +214,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||
String evalMetric = eval.getMetric();
|
||||
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
||||
metricsOut[i] = evalResult;
|
||||
}
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ public class XGBoost {
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, null, obj, eval);
|
||||
return train(dtrain, params, round, watches, null, obj, eval, 0);
|
||||
}
|
||||
|
||||
public static Booster train(
|
||||
@@ -74,7 +74,8 @@ public class XGBoost {
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRound) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -89,6 +90,7 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
@@ -120,19 +122,27 @@ public class XGBoost {
|
||||
|
||||
//evaluation
|
||||
if (evalMats.length > 0) {
|
||||
float[] metricsOut = new float[evalMats.length];
|
||||
String evalInfo;
|
||||
if (eval != null) {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, eval);
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut);
|
||||
} else {
|
||||
if (metrics == null) {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
} else {
|
||||
float[] m = new float[evalMats.length];
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
|
||||
for (int i = 0; i < m.length; i++) {
|
||||
metrics[i][iter] = m[i];
|
||||
}
|
||||
}
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||
}
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
|
||||
boolean decreasing = true;
|
||||
float[] criterion = metrics[metrics.length - 1];
|
||||
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
|
||||
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||
}
|
||||
|
||||
if (!decreasing) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d decreasing rounds", earlyStoppingRound));
|
||||
break;
|
||||
}
|
||||
if (Rabit.getRank() == 0) {
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
|
||||
@@ -36,6 +36,9 @@ object XGBoost {
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRound if non-zero, training is stopped early once the
|
||||
* evaluation metric of the last matrix in watches
|
||||
* increases within that many most recent rounds.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
@@ -45,44 +48,20 @@ object XGBoost {
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix],
|
||||
metrics: Array[Array[Float]],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
watches: Map[String, DMatrix] = Map(),
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava, metrics, obj, eval)
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Booster = {
|
||||
train(dtrain, params, round, watches, null, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
|
||||
@@ -23,11 +23,10 @@ import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
@@ -37,16 +36,9 @@ import org.junit.Test;
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
public static class EvalError implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
return "custom_error";
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -56,8 +48,7 @@ public class BoosterImplTest {
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
@@ -150,11 +141,55 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
||||
}
|
||||
|
||||
private static class IncreasingEval implements IEvaluation {
|
||||
private int value = 0;
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return "inc";
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
return value++;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||
watches.put("training", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
final int round = 10;
|
||||
int earlyStoppingRound = 2;
|
||||
float[][] metrics = new float[watches.size()][round];
|
||||
XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(),
|
||||
earlyStoppingRound);
|
||||
|
||||
// Make sure we've stopped early.
|
||||
for (int w = 0; w < watches.size(); w++) {
|
||||
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
||||
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
||||
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
||||
float[][] metrics = new float[watches.size()][round];
|
||||
Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
|
||||
metrics, null, null);
|
||||
metrics, null, null, 0);
|
||||
for (int i = 0; i < metrics.length; i++)
|
||||
for (int j = 1; j < metrics[i].length; j++) {
|
||||
TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
|
||||
|
||||
@@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||
|
||||
val round = 2
|
||||
XGBoost.train(trainMat, paramMap, round, watches, null, null)
|
||||
XGBoost.train(trainMat, paramMap, round, watches)
|
||||
}
|
||||
|
||||
private def trainBoosterWithFastHisto(
|
||||
@@ -84,7 +84,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
paramMap: Map[String, String],
|
||||
threshold: Float): Booster = {
|
||||
val metrics = Array.fill(watches.size, round)(0.0f)
|
||||
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null)
|
||||
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics)
|
||||
for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
|
||||
assert(metrics(i)(j) >= metrics(i)(j - 1))
|
||||
}
|
||||
@@ -143,7 +143,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
|
||||
XGBoost.crossValidation(trainMat, params, round, nfold)
|
||||
}
|
||||
|
||||
test("test with fast histo depthwise") {
|
||||
|
||||
Reference in New Issue
Block a user