[jvm-packages] Implemented early stopping (#2710)
* Allowed subsampling test from the training data frame/RDD.
  The implementation requires storing a fraction of (1 - trainTestRatio) of the points in memory to make the sampling work. An alternative approach would be to construct the full DMatrix and then slice it deterministically into train/test. The peak memory consumption of such a scenario, however, is twice the dataset size.
* Removed duplication from 'XGBoost.train'.
  Scala callers can (and should) use named arguments to supply a subset of parameters. Method overloading is not required.
* Reused the XGBoost seed parameter to stabilize train/test splitting.
* Added early stopping support to non-distributed XGBoost. Closes #1544.
* Added early stopping to distributed XGBoost.
* Moved construction of 'watches' into a separate method.
  This commit also fixes the handling of 'baseMargin', which previously was not added to the validation matrix.
* Addressed review comments.
This commit is contained in:
parent
74db9757b3
commit
69c3b78a29
@ -85,7 +85,7 @@ object BasicWalkThrough {
|
|||||||
val watches2 = new mutable.HashMap[String, DMatrix]
|
val watches2 = new mutable.HashMap[String, DMatrix]
|
||||||
watches2 += "train" -> trainMax2
|
watches2 += "train" -> trainMax2
|
||||||
watches2 += "test" -> testMax2
|
watches2 += "test" -> testMax2
|
||||||
val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null)
|
val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap)
|
||||||
val predicts3 = booster3.predict(testMax2)
|
val predicts3 = booster3.predict(testMax2)
|
||||||
println(checkPredicts(predicts, predicts3))
|
println(checkPredicts(predicts, predicts3))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -41,6 +41,6 @@ object CrossValidation {
|
|||||||
val metrics: Array[String] = null
|
val metrics: Array[String] = null
|
||||||
|
|
||||||
val evalHist: Array[String] =
|
val evalHist: Array[String] =
|
||||||
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null)
|
XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -151,7 +151,8 @@ object CustomObjective {
|
|||||||
val round = 2
|
val round = 2
|
||||||
// train a model
|
// train a model
|
||||||
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
|
val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
|
||||||
XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError)
|
XGBoost.train(trainMat, params.toMap, round, watches.toMap,
|
||||||
|
obj = new LogRegObj, eval = new EvalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -54,6 +54,6 @@ object ExternalMemory {
|
|||||||
testMat.setBaseMargin(testPred)
|
testMat.setBaseMargin(testPred)
|
||||||
|
|
||||||
System.out.println("result of running from initial prediction")
|
System.out.println("result of running from initial prediction")
|
||||||
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
|
val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -52,7 +52,7 @@ object GeneralizedLinearModel {
|
|||||||
watches += "test" -> testMat
|
watches += "test" -> testMat
|
||||||
|
|
||||||
val round = 4
|
val round = 4
|
||||||
val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
|
val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
|
||||||
val predicts = booster.predict(testMat)
|
val predicts = booster.predict(testMat)
|
||||||
val eval = new CustomEval
|
val eval = new CustomEval
|
||||||
println(s"error=${eval.eval(predicts, testMat)}")
|
println(s"error=${eval.eval(predicts, testMat)}")
|
||||||
|
|||||||
@ -55,7 +55,10 @@ object XGBoost {
|
|||||||
val trainMat = new DMatrix(dataIter, null)
|
val trainMat = new DMatrix(dataIter, null)
|
||||||
val watches = List("train" -> trainMat).toMap
|
val watches = List("train" -> trainMat).toMap
|
||||||
val round = 2
|
val round = 2
|
||||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
|
val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
|
||||||
|
.map(_.toString.toInt).getOrElse(0)
|
||||||
|
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||||
|
earlyStoppingRound = numEarlyStoppingRounds)
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
collector.collect(new XGBoostModel(booster))
|
collector.collect(new XGBoostModel(booster))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
@ -25,9 +26,9 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|||||||
|
|
||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||||
|
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.Dataset
|
import org.apache.spark.sql.Dataset
|
||||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
|
||||||
import org.apache.spark.{SparkContext, TaskContext}
|
import org.apache.spark.{SparkContext, TaskContext}
|
||||||
|
|
||||||
object TrackerConf {
|
object TrackerConf {
|
||||||
@ -94,7 +95,7 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def buildDistributedBoosters(
|
private[spark] def buildDistributedBoosters(
|
||||||
trainingSet: RDD[XGBLabeledPoint],
|
data: RDD[XGBLabeledPoint],
|
||||||
params: Map[String, Any],
|
params: Map[String, Any],
|
||||||
rabitEnv: java.util.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
numWorkers: Int,
|
numWorkers: Int,
|
||||||
@ -103,19 +104,19 @@ object XGBoost extends Serializable {
|
|||||||
eval: EvalTrait,
|
eval: EvalTrait,
|
||||||
useExternalMemory: Boolean,
|
useExternalMemory: Boolean,
|
||||||
missing: Float): RDD[Booster] = {
|
missing: Float): RDD[Booster] = {
|
||||||
val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
|
val partitionedData = if (data.getNumPartitions != numWorkers) {
|
||||||
logger.info(s"repartitioning training set to $numWorkers partitions")
|
logger.info(s"repartitioning training set to $numWorkers partitions")
|
||||||
trainingSet.repartition(numWorkers)
|
data.repartition(numWorkers)
|
||||||
} else {
|
} else {
|
||||||
trainingSet
|
data
|
||||||
}
|
}
|
||||||
val partitionedBaseMargin = partitionedTrainingSet.map(_.baseMargin)
|
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
||||||
val appName = partitionedTrainingSet.context.appName
|
val appName = partitionedData.context.appName
|
||||||
// to workaround the empty partitions in training dataset,
|
// to workaround the empty partitions in training dataset,
|
||||||
// this might not be the best efficient implementation, see
|
// this might not be the best efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingPoints, baseMargins) =>
|
partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
|
||||||
if (trainingPoints.isEmpty) {
|
if (labeledPoints.isEmpty) {
|
||||||
throw new XGBoostError(
|
throw new XGBoostError(
|
||||||
s"detected an empty partition in the training data, partition ID:" +
|
s"detected an empty partition in the training data, partition ID:" +
|
||||||
s" ${TaskContext.getPartitionId()}")
|
s" ${TaskContext.getPartitionId()}")
|
||||||
@ -128,21 +129,20 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||||
Rabit.init(rabitEnv)
|
Rabit.init(rabitEnv)
|
||||||
val trainingMatrix = new DMatrix(
|
val watches = Watches(params,
|
||||||
fromDenseToSparseLabeledPoints(trainingPoints, missing), cacheFileName)
|
fromDenseToSparseLabeledPoints(labeledPoints, missing),
|
||||||
|
fromBaseMarginsToArray(baseMargins), cacheFileName)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// TODO: use group attribute from the points.
|
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
||||||
if (params.contains("groupData") && params("groupData") != null) {
|
.map(_.toString.toInt).getOrElse(0)
|
||||||
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
val booster = SXGBoost.train(watches.train, params, round,
|
||||||
TaskContext.getPartitionId()).toArray)
|
watches = watches.toMap, obj = obj, eval = eval,
|
||||||
}
|
earlyStoppingRound = numEarlyStoppingRounds)
|
||||||
fromBaseMarginsToArray(baseMargins).foreach(trainingMatrix.setBaseMargin)
|
|
||||||
val booster = SXGBoost.train(trainingMatrix, params, round,
|
|
||||||
watches = Map("train" -> trainingMatrix), obj, eval)
|
|
||||||
Iterator(booster)
|
Iterator(booster)
|
||||||
} finally {
|
} finally {
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
trainingMatrix.delete()
|
watches.delete()
|
||||||
}
|
}
|
||||||
}.cache()
|
}.cache()
|
||||||
}
|
}
|
||||||
@ -417,3 +417,46 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class Watches private(val train: DMatrix, val test: DMatrix) {
|
||||||
|
def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
|
||||||
|
.filter { case (_, matrix) => matrix.rowNum > 0 }
|
||||||
|
|
||||||
|
def size: Int = toMap.size
|
||||||
|
|
||||||
|
def delete(): Unit = {
|
||||||
|
toMap.values.foreach(_.delete())
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString: String = toMap.toString
|
||||||
|
}
|
||||||
|
|
||||||
|
private object Watches {
|
||||||
|
def apply(
|
||||||
|
params: Map[String, Any],
|
||||||
|
labeledPoints: Iterator[XGBLabeledPoint],
|
||||||
|
baseMarginsOpt: Option[Array[Float]],
|
||||||
|
cacheFileName: String): Watches = {
|
||||||
|
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
|
||||||
|
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
||||||
|
val r = new Random(seed)
|
||||||
|
// In the worst-case this would store [[trainTestRatio]] of points
|
||||||
|
// buffered in memory.
|
||||||
|
val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)
|
||||||
|
val trainMatrix = new DMatrix(trainPoints, cacheFileName)
|
||||||
|
val testMatrix = new DMatrix(testPoints, cacheFileName)
|
||||||
|
r.setSeed(seed)
|
||||||
|
for (baseMargins <- baseMarginsOpt) {
|
||||||
|
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
||||||
|
trainMatrix.setBaseMargin(trainMargin)
|
||||||
|
testMatrix.setBaseMargin(testMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: use group attribute from the points.
|
||||||
|
if (params.contains("groupData") && params("groupData") != null) {
|
||||||
|
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||||
|
TaskContext.getPartitionId()).toArray)
|
||||||
|
}
|
||||||
|
new Watches(train = trainMatrix, test = testMatrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark.params
|
package ml.dmlc.xgboost4j.scala.spark.params
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
trait GeneralParams extends Params {
|
trait GeneralParams extends Params {
|
||||||
@ -99,9 +99,12 @@ trait GeneralParams extends Params {
|
|||||||
*/
|
*/
|
||||||
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
|
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
|
||||||
|
|
||||||
|
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
||||||
|
val seed = new LongParam(this, "seed", "random seed")
|
||||||
|
|
||||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||||
useExternalMemory -> false, silent -> 0,
|
useExternalMemory -> false, silent -> 0,
|
||||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||||
trackerConf -> TrackerConf()
|
trackerConf -> TrackerConf(), seed -> 0
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
|||||||
|
|
||||||
import scala.collection.immutable.HashSet
|
import scala.collection.immutable.HashSet
|
||||||
|
|
||||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
import org.apache.spark.ml.param._
|
||||||
|
|
||||||
trait LearningTaskParams extends Params {
|
trait LearningTaskParams extends Params {
|
||||||
|
|
||||||
@ -70,8 +70,25 @@ trait LearningTaskParams extends Params {
|
|||||||
*/
|
*/
|
||||||
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fraction of training points to use for testing.
|
||||||
|
*/
|
||||||
|
val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||||
|
"fraction of training points to use for testing",
|
||||||
|
ParamValidators.inRange(0, 1))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If non-zero, the training will be stopped after a specified number
|
||||||
|
* of consecutive increases in any evaluation metric.
|
||||||
|
*/
|
||||||
|
val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||||
|
"number of rounds of decreasing eval metric to tolerate before " +
|
||||||
|
"stopping the training",
|
||||||
|
(value: Int) => value == 0 || value > 1)
|
||||||
|
|
||||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
||||||
baseMarginCol -> "baseMargin", weightCol -> "weight")
|
baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
|
||||||
|
numEarlyStoppingRounds -> 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] object LearningTaskParams {
|
private[spark] object LearningTaskParams {
|
||||||
|
|||||||
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
import org.apache.spark.ml.linalg.DenseVector
|
import org.apache.spark.ml.linalg.DenseVector
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
@ -201,7 +202,8 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
|||||||
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
|
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
|
||||||
val testRDD = sc.parallelize(Classification.test.map(_.features))
|
val testRDD = sc.parallelize(Classification.test.map(_.features))
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic", "baseMarginCol" -> "margin")
|
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
|
||||||
|
"testTrainSplit" -> 0.5)
|
||||||
|
|
||||||
def trainPredict(df: Dataset[_]): Array[Float] = {
|
def trainPredict(df: Dataset[_]): Array[Float] = {
|
||||||
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
|
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
|
||||||
|
|||||||
@ -201,6 +201,12 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
*/
|
*/
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
||||||
throws XGBoostError {
|
throws XGBoostError {
|
||||||
|
// Hopefully, a tiny redundant allocation wouldn't hurt.
|
||||||
|
return evalSet(evalMatrixs, evalNames, eval, new float[evalNames.length]);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval,
|
||||||
|
float[] metricsOut) throws XGBoostError {
|
||||||
String evalInfo = "";
|
String evalInfo = "";
|
||||||
for (int i = 0; i < evalNames.length; i++) {
|
for (int i = 0; i < evalNames.length; i++) {
|
||||||
String evalName = evalNames[i];
|
String evalName = evalNames[i];
|
||||||
@ -208,6 +214,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||||
String evalMetric = eval.getMetric();
|
String evalMetric = eval.getMetric();
|
||||||
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
||||||
|
metricsOut[i] = evalResult;
|
||||||
}
|
}
|
||||||
return evalInfo;
|
return evalInfo;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,7 +64,7 @@ public class XGBoost {
|
|||||||
Map<String, DMatrix> watches,
|
Map<String, DMatrix> watches,
|
||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval) throws XGBoostError {
|
IEvaluation eval) throws XGBoostError {
|
||||||
return train(dtrain, params, round, watches, null, obj, eval);
|
return train(dtrain, params, round, watches, null, obj, eval, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Booster train(
|
public static Booster train(
|
||||||
@ -74,7 +74,8 @@ public class XGBoost {
|
|||||||
Map<String, DMatrix> watches,
|
Map<String, DMatrix> watches,
|
||||||
float[][] metrics,
|
float[][] metrics,
|
||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval) throws XGBoostError {
|
IEvaluation eval,
|
||||||
|
int earlyStoppingRound) throws XGBoostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
@ -89,6 +90,7 @@ public class XGBoost {
|
|||||||
|
|
||||||
evalNames = names.toArray(new String[names.size()]);
|
evalNames = names.toArray(new String[names.size()]);
|
||||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||||
|
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
DMatrix[] allMats;
|
DMatrix[] allMats;
|
||||||
@ -120,19 +122,27 @@ public class XGBoost {
|
|||||||
|
|
||||||
//evaluation
|
//evaluation
|
||||||
if (evalMats.length > 0) {
|
if (evalMats.length > 0) {
|
||||||
|
float[] metricsOut = new float[evalMats.length];
|
||||||
String evalInfo;
|
String evalInfo;
|
||||||
if (eval != null) {
|
if (eval != null) {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, eval);
|
evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut);
|
||||||
} else {
|
} else {
|
||||||
if (metrics == null) {
|
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
}
|
||||||
} else {
|
for (int i = 0; i < metricsOut.length; i++) {
|
||||||
float[] m = new float[evalMats.length];
|
metrics[i][iter] = metricsOut[i];
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
|
}
|
||||||
for (int i = 0; i < m.length; i++) {
|
|
||||||
metrics[i][iter] = m[i];
|
boolean decreasing = true;
|
||||||
}
|
float[] criterion = metrics[metrics.length - 1];
|
||||||
}
|
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
|
||||||
|
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!decreasing) {
|
||||||
|
Rabit.trackerPrint(String.format(
|
||||||
|
"early stopping after %d decreasing rounds", earlyStoppingRound));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
if (Rabit.getRank() == 0) {
|
if (Rabit.getRank() == 0) {
|
||||||
Rabit.trackerPrint(evalInfo + '\n');
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
|
|||||||
@ -36,6 +36,9 @@ object XGBoost {
|
|||||||
* performance on the validation set.
|
* performance on the validation set.
|
||||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
* iteration
|
* iteration
|
||||||
|
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||||
|
* after a specified number of consecutive
|
||||||
|
* increases in any evaluation metric.
|
||||||
* @param obj customized objective
|
* @param obj customized objective
|
||||||
* @param eval customized evaluation
|
* @param eval customized evaluation
|
||||||
* @return The trained booster.
|
* @return The trained booster.
|
||||||
@ -45,44 +48,20 @@ object XGBoost {
|
|||||||
dtrain: DMatrix,
|
dtrain: DMatrix,
|
||||||
params: Map[String, Any],
|
params: Map[String, Any],
|
||||||
round: Int,
|
round: Int,
|
||||||
watches: Map[String, DMatrix],
|
watches: Map[String, DMatrix] = Map(),
|
||||||
metrics: Array[Array[Float]],
|
metrics: Array[Array[Float]] = null,
|
||||||
obj: ObjectiveTrait,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait): Booster = {
|
eval: EvalTrait = null,
|
||||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
earlyStoppingRound: Int = 0): Booster = {
|
||||||
|
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||||
val xgboostInJava = JXGBoost.train(
|
val xgboostInJava = JXGBoost.train(
|
||||||
dtrain.jDMatrix,
|
dtrain.jDMatrix,
|
||||||
// we have to filter null value for customized obj and eval
|
// we have to filter null value for customized obj and eval
|
||||||
params.filter(_._2 != null).map{
|
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||||
case (key: String, value) => (key, value.toString)
|
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
||||||
}.toMap[String, AnyRef].asJava,
|
|
||||||
round, jWatches.asJava, metrics, obj, eval)
|
|
||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Train a booster given parameters.
|
|
||||||
*
|
|
||||||
* @param dtrain Data to be trained.
|
|
||||||
* @param params Parameters.
|
|
||||||
* @param round Number of boosting iterations.
|
|
||||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
|
||||||
* performance on the validation set.
|
|
||||||
* @param obj customized objective
|
|
||||||
* @param eval customized evaluation
|
|
||||||
* @return The trained booster.
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def train(
|
|
||||||
dtrain: DMatrix,
|
|
||||||
params: Map[String, Any],
|
|
||||||
round: Int,
|
|
||||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
|
||||||
obj: ObjectiveTrait = null,
|
|
||||||
eval: EvalTrait = null): Booster = {
|
|
||||||
train(dtrain, params, round, watches, null, obj, eval)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cross-validation with given parameters.
|
* Cross-validation with given parameters.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -23,11 +23,10 @@ import java.nio.file.Files;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
import org.apache.commons.logging.Log;
|
|
||||||
import org.apache.commons.logging.LogFactory;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -37,16 +36,9 @@ import org.junit.Test;
|
|||||||
*/
|
*/
|
||||||
public class BoosterImplTest {
|
public class BoosterImplTest {
|
||||||
public static class EvalError implements IEvaluation {
|
public static class EvalError implements IEvaluation {
|
||||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
|
||||||
|
|
||||||
String evalMetric = "custom_error";
|
|
||||||
|
|
||||||
public EvalError() {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getMetric() {
|
public String getMetric() {
|
||||||
return evalMetric;
|
return "custom_error";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -56,8 +48,7 @@ public class BoosterImplTest {
|
|||||||
try {
|
try {
|
||||||
labels = dmat.getLabel();
|
labels = dmat.getLabel();
|
||||||
} catch (XGBoostError ex) {
|
} catch (XGBoostError ex) {
|
||||||
logger.error(ex);
|
throw new RuntimeException(ex);
|
||||||
return -1f;
|
|
||||||
}
|
}
|
||||||
int nrow = predicts.length;
|
int nrow = predicts.length;
|
||||||
for (int i = 0; i < nrow; i++) {
|
for (int i = 0; i < nrow; i++) {
|
||||||
@ -150,11 +141,55 @@ public class BoosterImplTest {
|
|||||||
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class IncreasingEval implements IEvaluation {
|
||||||
|
private int value = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMetric() {
|
||||||
|
return "inc";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float eval(float[][] predicts, DMatrix dmat) {
|
||||||
|
return value++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBoosterEarlyStop() throws XGBoostError, IOException {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
// testBoosterWithFastHistogram(trainMat, testMat);
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("max_depth", 3);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Map<String, DMatrix> watches = new LinkedHashMap<>();
|
||||||
|
watches.put("training", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
final int round = 10;
|
||||||
|
int earlyStoppingRound = 2;
|
||||||
|
float[][] metrics = new float[watches.size()][round];
|
||||||
|
XGBoost.train(trainMat, paramMap, round, watches, metrics, null, new IncreasingEval(),
|
||||||
|
earlyStoppingRound);
|
||||||
|
|
||||||
|
// Make sure we've stopped early.
|
||||||
|
for (int w = 0; w < watches.size(); w++) {
|
||||||
|
for (int r = earlyStoppingRound + 1; r < round; r++) {
|
||||||
|
TestCase.assertEquals(0.0f, metrics[w][r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
private void testWithFastHisto(DMatrix trainingSet, Map<String, DMatrix> watches, int round,
|
||||||
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
Map<String, Object> paramMap, float threshold) throws XGBoostError {
|
||||||
float[][] metrics = new float[watches.size()][round];
|
float[][] metrics = new float[watches.size()][round];
|
||||||
Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
|
Booster booster = XGBoost.train(trainingSet, paramMap, round, watches,
|
||||||
metrics, null, null);
|
metrics, null, null, 0);
|
||||||
for (int i = 0; i < metrics.length; i++)
|
for (int i = 0; i < metrics.length; i++)
|
||||||
for (int j = 1; j < metrics[i].length; j++) {
|
for (int j = 1; j < metrics[i].length; j++) {
|
||||||
TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
|
TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
|
||||||
|
|||||||
@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||||
|
|
||||||
val round = 2
|
val round = 2
|
||||||
XGBoost.train(trainMat, paramMap, round, watches, null, null)
|
XGBoost.train(trainMat, paramMap, round, watches)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def trainBoosterWithFastHisto(
|
private def trainBoosterWithFastHisto(
|
||||||
@ -84,7 +84,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
paramMap: Map[String, String],
|
paramMap: Map[String, String],
|
||||||
threshold: Float): Booster = {
|
threshold: Float): Booster = {
|
||||||
val metrics = Array.fill(watches.size, round)(0.0f)
|
val metrics = Array.fill(watches.size, round)(0.0f)
|
||||||
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics, null, null)
|
val booster = XGBoost.train(trainMat, paramMap, round, watches, metrics)
|
||||||
for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
|
for (i <- 0 until watches.size; j <- 1 until metrics(i).length) {
|
||||||
assert(metrics(i)(j) >= metrics(i)(j - 1))
|
assert(metrics(i)(j) >= metrics(i)(j - 1))
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||||
val round = 2
|
val round = 2
|
||||||
val nfold = 5
|
val nfold = 5
|
||||||
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
|
XGBoost.crossValidation(trainMat, params, round, nfold)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test with fast histo depthwise") {
|
test("test with fast histo depthwise") {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user