[jvm-packages] Implemented early stopping (#2710)

* Allowed subsampling test from the training data frame/RDD

The implementation requires storing 1 - trainTestRatio points in memory
to make the sampling work.

An alternative approach would be to construct the full DMatrix and then
slice it deterministically into train/test. The peak memory consumption
of such scenario, however, is twice the dataset size.

* Removed duplication from 'XGBoost.train'

Scala callers can (and should) use names to supply a subset of
parameters. Method overloading is not required.

* Reuse XGBoost seed parameter to stabilize train/test splitting

* Added early stopping support to non-distributed XGBoost

Closes #1544

* Added early-stopping to distributed XGBoost

* Moved construction of 'watches' into a separate method

This commit also fixes the handling of 'baseMargin' which previously
was not added to the validation matrix.

* Addressed review comments
This commit is contained in:
Sergei Lebedev
2017-09-29 21:06:22 +02:00
committed by Nan Zhu
parent 74db9757b3
commit 69c3b78a29
15 changed files with 191 additions and 91 deletions

View File

@@ -23,11 +23,10 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import junit.framework.TestCase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
/**
@@ -37,16 +36,9 @@ import org.junit.Test;
*/
public class BoosterImplTest {
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
return "custom_error";
}
@Override
@@ -56,8 +48,7 @@ public class BoosterImplTest {
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
throw new RuntimeException(ex);
}
int nrow = predicts.length;
for (int i = 0; i < nrow; i++) {
@@ -150,11 +141,55 @@ public class BoosterImplTest {
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
}
private static class IncreasingEval implements IEvaluation {
private int value = 0;
@Override
public String getMetric() {
return "inc";
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
return value++;
}
}
@Test
public void testBoosterEarlyStop() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
// testBoosterWithFastHistogram(trainMat, testMat);
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 3);
put("silent", 1);
put("objective", "binary:logistic");
}
};
Map<String, DMatrix> watches = new LinkedHashMap<>();
watches.put("training", trainMat);
watches.put("test", testMat);
final int round = 10;
int earlyStoppingRound = 2;
float[][] metrics = new float[watches.size()][round];
XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(),
earlyStoppingRound);
// Make sure we've stopped early.
for (int w = 0; w < watches.size(); w++) {
for (int r = earlyStoppingRound + 1; r < round; r++) {
TestCase.assertEquals(0.0f, metrics[w][r]);
}
}
}
private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
Map<String, Object> paramMap, float threshold) throws XGBoostError {
float[][] metrics = new float[watches.size()][round];
Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
metrics, null, null);
metrics, null, null, 0);
for (int i = 0; i < metrics.length; i++)
for (int j = 1; j < metrics[i].length; j++) {
TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);

View File

@@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2
XGBoost.train(trainMat, paramMap, round, watches, null, null)
XGBoost.train(trainMat, paramMap, round, watches)
}
private def trainBoosterWithFastHisto(
@@ -84,7 +84,7 @@ class ScalaBoosterImplSuite extends FunSuite {
paramMap: Map[String, String],
threshold: Float): Booster = {
val metrics = Array.fill(watches.size, round)(0.0f)
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null)
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics)
for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
assert(metrics(i)(j) >= metrics(i)(j - 1))
}
@@ -143,7 +143,7 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
XGBoost.crossValidation(trainMat, params, round, nfold)
}
test("test with fast histo depthwise") {