From 69c3b78a292409f7d28867836f03f97ab4226b7d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 29 Sep 2017 21:06:22 +0200 Subject: [PATCH] [jvm-packages] Implemented early stopping (#2710) * Allowed subsampling test from the training data frame/RDD The implementation requires storing 1 - trainTestRatio points in memory to make the sampling work. An alternative approach would be to construct the full DMatrix and then slice it deterministically into train/test. The peak memory consumption of such scenario, however, is twice the dataset size. * Removed duplication from 'XGBoost.train' Scala callers can (and should) use names to supply a subset of parameters. Method overloading is not required. * Reuse XGBoost seed parameter to stabilize train/test splitting * Added early stopping support to non-distributed XGBoost Closes #1544 * Added early-stopping to distributed XGBoost * Moved construction of 'watches' into a separate method This commit also fixes the handling of 'baseMargin' which previously was not added to the validation matrix. * Addressed review comments --- .../scala/example/BasicWalkThrough.scala | 2 +- .../scala/example/CrossValidation.scala | 2 +- .../scala/example/CustomObjective.scala | 3 +- .../scala/example/ExternalMemory.scala | 2 +- .../example/GeneralizedLinearModel.scala | 2 +- .../dmlc/xgboost4j/scala/flink/XGBoost.scala | 5 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 83 ++++++++++++++----- .../scala/spark/params/GeneralParams.scala | 7 +- .../spark/params/LearningTaskParams.scala | 21 ++++- .../scala/spark/XGBoostDFSuite.scala | 4 +- .../java/ml/dmlc/xgboost4j/java/Booster.java | 7 ++ .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 34 +++++--- .../ml/dmlc/xgboost4j/scala/XGBoost.scala | 43 +++------- .../dmlc/xgboost4j/java/BoosterImplTest.java | 61 +++++++++++--- .../scala/ScalaBoosterImplSuite.scala | 6 +- 15 files changed, 191 insertions(+), 91 deletions(-) diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala index ffc0c6a1d..e8481b047 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala @@ -85,7 +85,7 @@ object BasicWalkThrough { val watches2 = new mutable.HashMap[String, DMatrix] watches2 += "train" -> trainMax2 watches2 += "test" -> testMax2 - val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null) + val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap) val predicts3 = booster3.predict(testMax2) println(checkPredicts(predicts, predicts3)) } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala index 8fd7581f4..62f8b461a 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala @@ -41,6 +41,6 @@ object CrossValidation { val metrics: Array[String] = null val evalHist: Array[String] = - XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null) + XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics) } } diff --git 
a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala index 58afd82e1..fe88423e7 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala @@ -151,7 +151,8 @@ object CustomObjective { val round = 2 // train a model val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap) - XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError) + XGBoost.train(trainMat, params.toMap, round, watches.toMap, + obj = new LogRegObj, eval = new EvalError) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala index cdf3d3e9e..447c98295 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala @@ -54,6 +54,6 @@ object ExternalMemory { testMat.setBaseMargin(testPred) System.out.println("result of running from initial prediction") - val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null) + val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap) } } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala index 966f04619..27ed98eca 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala @@ -52,7 +52,7 @@ object GeneralizedLinearModel { watches += "test" -> testMat val round = 4 - val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null) + val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap) val predicts = booster.predict(testMat) val eval = new CustomEval println(s"error=${eval.eval(predicts, testMat)}") diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala index fe34783cd..fa0d8b623 100644 --- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala +++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala @@ -55,7 +55,10 @@ object XGBoost { val trainMat = new DMatrix(dataIter, null) val watches = List("train" -> trainMat).toMap val round = 2 - val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null) + val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds") + .map(_.toString.toInt).getOrElse(0) + val booster = XGBoostScala.train(trainMat, paramMap, round, watches, + earlyStoppingRound = numEarlyStoppingRounds) Rabit.shutdown() collector.collect(new XGBoostModel(booster)) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 
9187e2e41..f3ab0cd08 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.mutable +import scala.util.Random import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker @@ -25,9 +26,9 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.{FSDataInputStream, Path} +import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset -import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} import org.apache.spark.{SparkContext, TaskContext} object TrackerConf { @@ -94,7 +95,7 @@ object XGBoost extends Serializable { } private[spark] def buildDistributedBoosters( - trainingSet: RDD[XGBLabeledPoint], + data: RDD[XGBLabeledPoint], params: Map[String, Any], rabitEnv: java.util.Map[String, String], numWorkers: Int, @@ -103,19 +104,19 @@ object XGBoost extends Serializable { eval: EvalTrait, useExternalMemory: Boolean, missing: Float): RDD[Booster] = { - val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) { + val partitionedData = if (data.getNumPartitions != numWorkers) { logger.info(s"repartitioning training set to $numWorkers partitions") - trainingSet.repartition(numWorkers) + data.repartition(numWorkers) } else { - trainingSet + data } - val partitionedBaseMargin = partitionedTrainingSet.map(_.baseMargin) - val appName = partitionedTrainingSet.context.appName + val partitionedBaseMargin = partitionedData.map(_.baseMargin) + val appName = partitionedData.context.appName // to workaround the empty partitions in training dataset, // this might not be the best efficient implementation, see // (https://github.com/dmlc/xgboost/issues/1277) - partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingPoints, baseMargins) => - if (trainingPoints.isEmpty) { + partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) => + if (labeledPoints.isEmpty) { throw new XGBoostError( s"detected an empty partition in the training data, partition ID:" + s" ${TaskContext.getPartitionId()}") @@ -128,21 +129,20 @@ object XGBoost extends Serializable { } rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString) Rabit.init(rabitEnv) - val trainingMatrix = new DMatrix( - fromDenseToSparseLabeledPoints(trainingPoints, missing), cacheFileName) + val watches = Watches(params, + fromDenseToSparseLabeledPoints(labeledPoints, missing), + fromBaseMarginsToArray(baseMargins), cacheFileName) + try { - // TODO: use group attribute from the points. 
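As an aside on the Watches helper introduced further down in this file: the train/test split happens lazily over the partition's iterator using a Random seeded from the "seed" parameter, and the same seed is replayed over the base margins so both passes assign every row to the same side. A minimal standalone sketch of that idea, with a hypothetical seededSplit helper that is not part of this patch:

    import scala.util.Random

    // Split an iterator into (train, test) with a seeded Random so the
    // assignment is reproducible; replaying the same seed over a parallel
    // iterator (e.g. the base margins) yields an identical split.
    def seededSplit[T](points: Iterator[T], trainTestRatio: Double, seed: Long)
        : (Iterator[T], Iterator[T]) = {
      val r = new Random(seed)
      // Iterator.partition buffers elements of whichever side is consumed
      // later, so building the train DMatrix first keeps roughly a
      // (1 - trainTestRatio) fraction of the points in memory, as noted in
      // the commit message.
      points.partition(_ => r.nextDouble() <= trainTestRatio)
    }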
- if (params.contains("groupData") && params("groupData") != null) { - trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]]( - TaskContext.getPartitionId()).toArray) - } - fromBaseMarginsToArray(baseMargins).foreach(trainingMatrix.setBaseMargin) - val booster = SXGBoost.train(trainingMatrix, params, round, - watches = Map("train" -> trainingMatrix), obj, eval) + val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds") + .map(_.toString.toInt).getOrElse(0) + val booster = SXGBoost.train(watches.train, params, round, + watches = watches.toMap, obj = obj, eval = eval, + earlyStoppingRound = numEarlyStoppingRounds) Iterator(booster) } finally { Rabit.shutdown() - trainingMatrix.delete() + watches.delete() } }.cache() } @@ -417,3 +417,46 @@ object XGBoost extends Serializable { } } } + +private class Watches private(val train: DMatrix, val test: DMatrix) { + def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test) + .filter { case (_, matrix) => matrix.rowNum > 0 } + + def size: Int = toMap.size + + def delete(): Unit = { + toMap.values.foreach(_.delete()) + } + + override def toString: String = toMap.toString +} + +private object Watches { + def apply( + params: Map[String, Any], + labeledPoints: Iterator[XGBLabeledPoint], + baseMarginsOpt: Option[Array[Float]], + cacheFileName: String): Watches = { + val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0) + val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) + val r = new Random(seed) + // In the worst-case this would store [[trainTestRatio]] of points + // buffered in memory. + val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio) + val trainMatrix = new DMatrix(trainPoints, cacheFileName) + val testMatrix = new DMatrix(testPoints, cacheFileName) + r.setSeed(seed) + for (baseMargins <- baseMarginsOpt) { + val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio) + trainMatrix.setBaseMargin(trainMargin) + testMatrix.setBaseMargin(testMargin) + } + + // TODO: use group attribute from the points. + if (params.contains("groupData") && params("groupData") != null) { + trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]]( + TaskContext.getPartitionId()).toArray) + } + new Watches(train = trainMatrix, test = testMatrix) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index af14ce43c..676c4eb47 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark.params import ml.dmlc.xgboost4j.scala.spark.TrackerConf -import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} + import org.apache.spark.ml.param._ trait GeneralParams extends Params { @@ -99,9 +99,12 @@ trait GeneralParams extends Params { */ val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations") + /** Random seed for the C++ part of XGBoost and train/test splitting. 
*/ + val seed = new LongParam(this, "seed", "random seed") + setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1, useExternalMemory -> false, silent -> 0, customObj -> null, customEval -> null, missing -> Float.NaN, - trackerConf -> TrackerConf() + trackerConf -> TrackerConf(), seed -> 0 ) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 2981246f4..b86c0de0a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params import scala.collection.immutable.HashSet -import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params} +import org.apache.spark.ml.param._ trait LearningTaskParams extends Params { @@ -70,8 +70,25 @@ trait LearningTaskParams extends Params { */ val weightCol = new Param[String](this, "weightCol", "weight column name") + /** + * Fraction of training points to use for testing. + */ + val trainTestRatio = new DoubleParam(this, "trainTestRatio", + "fraction of training points to use for testing", + ParamValidators.inRange(0, 1)) + + /** + * If non-zero, the training will be stopped after a specified number + * of consecutive increases in any evaluation metric. + */ + val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", + "number of rounds of decreasing eval metric to tolerate before " + + "stopping the training", + (value: Int) => value == 0 || value > 1) + setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null, - baseMarginCol -> "baseMargin", weightCol -> "weight") + baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0, + numEarlyStoppingRounds -> 0) } private[spark] object LearningTaskParams { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index d5ac77dbd..5954c5d53 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} + import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql._ @@ -201,7 +202,8 @@ class XGBoostDFSuite extends FunSuite with PerTest { val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand()) val testRDD = sc.parallelize(Classification.test.map(_.features)) val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", "baseMarginCol" -> "margin") + "objective" -> "binary:logistic", "baseMarginCol" -> "margin", + "testTrainSplit" -> 0.5) def trainPredict(df: Dataset[_]): Array[Float] = { XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 3b1476e54..40e00db7c 
100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -201,6 +201,12 @@ public class Booster implements Serializable, KryoSerializable { */ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError { + // Hopefully, a tiny redundant allocation wouldn't hurt. + return evalSet(evalMatrixs, evalNames, eval, new float[evalNames.length]); + } + + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval, + float[] metricsOut) throws XGBoostError { String evalInfo = ""; for (int i = 0; i < evalNames.length; i++) { String evalName = evalNames[i]; @@ -208,6 +214,7 @@ public class Booster implements Serializable, KryoSerializable { float evalResult = eval.eval(predict(evalMat), evalMat); String evalMetric = eval.getMetric(); evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); + metricsOut[i] = evalResult; } return evalInfo; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index b8601b19e..8021e878f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -64,7 +64,7 @@ public class XGBoost { Map watches, IObjective obj, IEvaluation eval) throws XGBoostError { - return train(dtrain, params, round, watches, null, obj, eval); + return train(dtrain, params, round, watches, null, obj, eval, 0); } public static Booster train( @@ -74,7 +74,8 @@ public class XGBoost { Map watches, float[][] metrics, IObjective obj, - IEvaluation eval) throws XGBoostError { + IEvaluation eval, + int earlyStoppingRound) throws XGBoostError { //collect eval matrixs String[] evalNames; @@ -89,6 +90,7 @@ public class XGBoost { evalNames = names.toArray(new String[names.size()]); evalMats = mats.toArray(new DMatrix[mats.size()]); + metrics = metrics == null ? 
new float[evalNames.length][round] : metrics; //collect all data matrixs DMatrix[] allMats; @@ -120,19 +122,27 @@ public class XGBoost { //evaluation if (evalMats.length > 0) { + float[] metricsOut = new float[evalMats.length]; String evalInfo; if (eval != null) { - evalInfo = booster.evalSet(evalMats, evalNames, eval); + evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut); } else { - if (metrics == null) { - evalInfo = booster.evalSet(evalMats, evalNames, iter); - } else { - float[] m = new float[evalMats.length]; - evalInfo = booster.evalSet(evalMats, evalNames, iter, m); - for (int i = 0; i < m.length; i++) { - metrics[i][iter] = m[i]; - } - } + evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut); + } + for (int i = 0; i < metricsOut.length; i++) { + metrics[i][iter] = metricsOut[i]; + } + + boolean decreasing = true; + float[] criterion = metrics[metrics.length - 1]; + for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) { + decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1]; + } + + if (!decreasing) { + Rabit.trackerPrint(String.format( + "early stopping after %d decreasing rounds", earlyStoppingRound)); + break; } if (Rabit.getRank() == 0) { Rabit.trackerPrint(evalInfo + '\n'); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 48d57af17..fc05e899d 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -36,6 +36,9 @@ object XGBoost { * performance on the validation set. * @param metrics array containing the evaluation metrics for each matrix in watches for each * iteration + * @param earlyStoppingRound if non-zero, training would be stopped + * after a specified number of consecutive + * increases in any evaluation metric. * @param obj customized objective * @param eval customized evaluation * @return The trained booster. @@ -45,44 +48,20 @@ object XGBoost { dtrain: DMatrix, params: Map[String, Any], round: Int, - watches: Map[String, DMatrix], - metrics: Array[Array[Float]], - obj: ObjectiveTrait, - eval: EvalTrait): Booster = { - val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} + watches: Map[String, DMatrix] = Map(), + metrics: Array[Array[Float]] = null, + obj: ObjectiveTrait = null, + eval: EvalTrait = null, + earlyStoppingRound: Int = 0): Booster = { + val jWatches = watches.mapValues(_.jDMatrix).asJava val xgboostInJava = JXGBoost.train( dtrain.jDMatrix, // we have to filter null value for customized obj and eval - params.filter(_._2 != null).map{ - case (key: String, value) => (key, value.toString) - }.toMap[String, AnyRef].asJava, - round, jWatches.asJava, metrics, obj, eval) + params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, + round, jWatches, metrics, obj, eval, earlyStoppingRound) new Booster(xgboostInJava) } - /** - * Train a booster given parameters. - * - * @param dtrain Data to be trained. - * @param params Parameters. - * @param round Number of boosting iterations. - * @param watches a group of items to be evaluated during training, this allows user to watch - * performance on the validation set. - * @param obj customized objective - * @param eval customized evaluation - * @return The trained booster. 
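To make the stopping rule added to XGBoost.java above easier to follow: only the metric of the last entry in watches is used as the criterion, and training stops once that criterion strictly increases anywhere within the trailing window of earlyStoppingRound iterations. A rough Scala restatement of the Java check (illustrative only, not part of the patch):

    // metrics(w)(i) is the value of watch w at iteration i; the last watch
    // drives early stopping.
    def shouldStop(metrics: Array[Array[Float]], iter: Int,
        earlyStoppingRound: Int): Boolean = {
      val criterion = metrics.last
      val window = math.min(iter, earlyStoppingRound) - 1
      // training continues while the criterion is non-increasing across the
      // window; a single strict increase triggers the stop
      earlyStoppingRound > 0 &&
        (0 until window).exists(shift => criterion(iter - shift) > criterion(iter - shift - 1))
    }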
- */ - @throws(classOf[XGBoostError]) - def train( - dtrain: DMatrix, - params: Map[String, Any], - round: Int, - watches: Map[String, DMatrix] = Map[String, DMatrix](), - obj: ObjectiveTrait = null, - eval: EvalTrait = null): Booster = { - train(dtrain, params, round, watches, null, obj, eval) - } - /** * Cross-validation with given parameters. * diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index b612e4306..0d790c9ce 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -23,11 +23,10 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.Map; import junit.framework.TestCase; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.junit.Test; /** @@ -37,16 +36,9 @@ import org.junit.Test; */ public class BoosterImplTest { public static class EvalError implements IEvaluation { - private static final Log logger = LogFactory.getLog(EvalError.class); - - String evalMetric = "custom_error"; - - public EvalError() { - } - @Override public String getMetric() { - return evalMetric; + return "custom_error"; } @Override @@ -56,8 +48,7 @@ public class BoosterImplTest { try { labels = dmat.getLabel(); } catch (XGBoostError ex) { - logger.error(ex); - return -1f; + throw new RuntimeException(ex); } int nrow = predicts.length; for (int i = 0; i < nrow; i++) { @@ -150,11 +141,55 @@ public class BoosterImplTest { TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f); } + private static class IncreasingEval implements IEvaluation { + private int value = 0; + + @Override + public String getMetric() { + return "inc"; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + return value++; + } + } + + @Test + public void testBoosterEarlyStop() throws XGBoostError, IOException { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + // testBoosterWithFastHistogram(trainMat, testMat); + Map paramMap = new HashMap() { + { + put("max_depth", 3); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + Map watches = new LinkedHashMap<>(); + watches.put("training", trainMat); + watches.put("test", testMat); + + final int round = 10; + int earlyStoppingRound = 2; + float[][] metrics = new float[watches.size()][round]; + XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(), + earlyStoppingRound); + + // Make sure we've stopped early. 
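With the old null-padded overload removed, Scala callers reach the same behaviour through named arguments on the single train method. A hedged sketch of a call that records per-iteration metrics and stops early (file paths and parameter values are illustrative):

    import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
    val watches = Map("train" -> trainMat, "test" -> testMat)
    val round = 10
    // one row of metrics per watch, one column per boosting round
    val metrics = Array.fill(watches.size, round)(0.0f)
    val booster = XGBoost.train(trainMat,
      Map("objective" -> "binary:logistic", "max_depth" -> "3", "silent" -> "1"),
      round, watches, metrics = metrics, earlyStoppingRound = 2)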
+ for (int w = 0; w < watches.size(); w++) { + for (int r = earlyStoppingRound + 1; r < round; r++) { + TestCase.assertEquals(0.0f, metrics[w][r]); + } + } + } + private void testWithFastHisto(DMatrix trainingSet, Map watches, int round, Map paramMap, float threshold) throws XGBoostError { float[][] metrics = new float[watches.size()][round]; Booster booster = XGBoost.train(trainingSet, paramMap, round, watches, - metrics, null, null); + metrics, null, null, 0); for (int i = 0; i < metrics.length; i++) for (int j = 1; j < metrics[i].length; j++) { TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]); diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala index fc4badc7b..2c3ce62a7 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite { val watches = List("train" -> trainMat, "test" -> testMat).toMap val round = 2 - XGBoost.train(trainMat, paramMap, round, watches, null, null) + XGBoost.train(trainMat, paramMap, round, watches) } private def trainBoosterWithFastHisto( @@ -84,7 +84,7 @@ class ScalaBoosterImplSuite extends FunSuite { paramMap: Map[String, String], threshold: Float): Booster = { val metrics = Array.fill(watches.size, round)(0.0f) - val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null) + val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics) for (i <- 0 until watches.size; j <- 1 until metrics(i).length) { assert(metrics(i)(j) >= metrics(i)(j - 1)) } @@ -143,7 +143,7 @@ class ScalaBoosterImplSuite extends FunSuite { "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap val round = 2 val nfold = 5 - XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null) + XGBoost.crossValidation(trainMat, params, round, nfold) } test("test with fast histo depthwise") {
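On the Spark side the new knobs are ordinary Params: trainTestRatio controls the fraction of each partition that goes into the train matrix (the remainder feeds the test matrix built by Watches), numEarlyStoppingRounds is forwarded to the trainer, and seed makes the split reproducible. A hedged usage sketch, assuming a DataFrame trainingDf with the default "features"/"label" columns:

    import org.apache.spark.sql.DataFrame
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    def trainWithEarlyStopping(trainingDf: DataFrame) = {
      val paramMap = Map(
        "eta" -> "0.1",
        "max_depth" -> "6",
        "objective" -> "binary:logistic",
        "trainTestRatio" -> 0.8,        // 80% of each partition used for training
        "numEarlyStoppingRounds" -> 5,  // consecutive increases tolerated before stopping
        "seed" -> 42L                   // reproducible train/test split
      )
      XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 100, nWorkers = 4)
    }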