[jvm-packages] Implemented early stopping (#2710)
* Allowed subsampling a test set from the training data frame/RDD. The implementation requires storing 1 - trainTestRatio of the points in memory to make the sampling work. An alternative approach would be to construct the full DMatrix and then slice it deterministically into train/test; the peak memory consumption of such a scenario, however, is twice the dataset size.
* Removed duplication from 'XGBoost.train'. Scala callers can (and should) use named arguments to supply a subset of parameters; method overloading is not required.
* Reused the XGBoost seed parameter to stabilize train/test splitting.
* Added early-stopping support to non-distributed XGBoost. Closes #1544.
* Added early stopping to distributed XGBoost.
* Moved construction of 'watches' into a separate method. This commit also fixes the handling of 'baseMargin', which previously was not added to the validation matrix.
* Addressed review comments.
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
@@ -25,9 +26,9 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
|
||||
object TrackerConf {
|
||||
@@ -94,7 +95,7 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
|
||||
private[spark] def buildDistributedBoosters(
|
||||
trainingSet: RDD[XGBLabeledPoint],
|
||||
data: RDD[XGBLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
numWorkers: Int,
|
||||
@@ -103,19 +104,19 @@ object XGBoost extends Serializable {
|
||||
eval: EvalTrait,
|
||||
useExternalMemory: Boolean,
|
||||
missing: Float): RDD[Booster] = {
|
||||
val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
|
||||
val partitionedData = if (data.getNumPartitions != numWorkers) {
|
||||
logger.info(s"repartitioning training set to $numWorkers partitions")
|
||||
trainingSet.repartition(numWorkers)
|
||||
data.repartition(numWorkers)
|
||||
} else {
|
||||
trainingSet
|
||||
data
|
||||
}
|
||||
val partitionedBaseMargin = partitionedTrainingSet.map(_.baseMargin)
|
||||
val appName = partitionedTrainingSet.context.appName
|
||||
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
||||
val appName = partitionedData.context.appName
|
||||
// to workaround the empty partitions in training dataset,
|
||||
// this might not be the best efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingPoints, baseMargins) =>
|
||||
if (trainingPoints.isEmpty) {
|
||||
partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
|
||||
if (labeledPoints.isEmpty) {
|
||||
throw new XGBoostError(
|
||||
s"detected an empty partition in the training data, partition ID:" +
|
||||
s" ${TaskContext.getPartitionId()}")
|
||||
@@ -128,21 +129,20 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv)
|
||||
val trainingMatrix = new DMatrix(
|
||||
fromDenseToSparseLabeledPoints(trainingPoints, missing), cacheFileName)
|
||||
val watches = Watches(params,
|
||||
fromDenseToSparseLabeledPoints(labeledPoints, missing),
|
||||
fromBaseMarginsToArray(baseMargins), cacheFileName)
|
||||
|
||||
try {
|
||||
// TODO: use group attribute from the points.
|
||||
if (params.contains("groupData") && params("groupData") != null) {
|
||||
trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
fromBaseMarginsToArray(baseMargins).foreach(trainingMatrix.setBaseMargin)
|
||||
val booster = SXGBoost.train(trainingMatrix, params, round,
|
||||
watches = Map("train" -> trainingMatrix), obj, eval)
|
||||
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = SXGBoost.train(watches.train, params, round,
|
||||
watches = watches.toMap, obj = obj, eval = eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Iterator(booster)
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
trainingMatrix.delete()
|
||||
watches.delete()
|
||||
}
|
||||
}.cache()
|
||||
}
|
||||
@@ -417,3 +417,46 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Holds the train and (possibly empty) test matrices passed to the booster
 * as evaluation watches. Instances are created via [[Watches.apply]].
 */
private class Watches private(val train: DMatrix, val test: DMatrix) {

  /** Named matrices with at least one row; an empty test split is dropped. */
  def toMap: Map[String, DMatrix] = {
    val candidates = Seq("train" -> train, "test" -> test)
    candidates.filter { case (_, matrix) => matrix.rowNum > 0 }.toMap
  }

  /** Number of non-empty watches (1 or 2). */
  def size: Int = toMap.size

  /** Releases the native memory behind every non-empty matrix. */
  def delete(): Unit = toMap.values.foreach(_.delete())

  override def toString: String = toMap.toString
}
|
||||
|
||||
private object Watches {

  /**
   * Splits the partition's points into train/test matrices and wires in the
   * optional base margins and group data.
   *
   * @param params          training parameters; reads "trainTestRatio",
   *                        "seed" and "groupData" if present
   * @param labeledPoints   all points of this partition (consumed lazily)
   * @param baseMarginsOpt  per-point base margins, in the same order as
   *                        [[labeledPoints]]
   * @param cacheFileName   external-memory cache prefix
   *                        NOTE(review): the same cache name is passed to both
   *                        the train and test DMatrix — confirm the native
   *                        layer does not collide on the cache file.
   */
  def apply(
      params: Map[String, Any],
      labeledPoints: Iterator[XGBLabeledPoint],
      baseMarginsOpt: Option[Array[Float]],
      cacheFileName: String): Watches = {
    // trainTestRatio == 1.0 (the default) keeps every point in the train split.
    val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
    // A fixed seed makes the split reproducible; it is also replayed below.
    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
    val r = new Random(seed)
    // Iterator.partition buffers points routed to the side not currently being
    // consumed — in the worst case roughly (1 - trainTestRatio) of the
    // partition is held in memory while the train matrix is built.
    val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)
    val trainMatrix = new DMatrix(trainPoints, cacheFileName)
    val testMatrix = new DMatrix(testPoints, cacheFileName)
    // Reset the seed so the margin partition below draws the exact same
    // random sequence and therefore produces the identical train/test split.
    r.setSeed(seed)
    for (baseMargins <- baseMarginsOpt) {
      val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
      trainMatrix.setBaseMargin(trainMargin)
      testMatrix.setBaseMargin(testMargin)
    }

    // TODO: use group attribute from the points.
    // Group data is indexed by partition id; it is applied to the train
    // matrix only.
    if (params.contains("groupData") && params("groupData") != null) {
      trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
        TaskContext.getPartitionId()).toArray)
    }
    new Watches(train = trainMatrix, test = testMatrix)
  }
}
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
trait GeneralParams extends Params {
|
||||
@@ -99,9 +99,12 @@ trait GeneralParams extends Params {
|
||||
*/
|
||||
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
|
||||
|
||||
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
||||
val seed = new LongParam(this, "seed", "random seed")
|
||||
|
||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||
useExternalMemory -> false, silent -> 0,
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||
trackerConf -> TrackerConf()
|
||||
trackerConf -> TrackerConf(), seed -> 0
|
||||
)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
trait LearningTaskParams extends Params {
|
||||
|
||||
@@ -70,8 +70,25 @@ trait LearningTaskParams extends Params {
|
||||
*/
|
||||
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
||||
|
||||
/**
|
||||
* Fraction of training points to use for testing.
|
||||
*/
|
||||
val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
"fraction of training points to use for testing",
|
||||
ParamValidators.inRange(0, 1))
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
*/
|
||||
val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||
"number of rounds of decreasing eval metric to tolerate before " +
|
||||
"stopping the training",
|
||||
(value: Int) => value == 0 || value > 1)
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
||||
baseMarginCol -> "baseMargin", weightCol -> "weight")
|
||||
baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
|
||||
numEarlyStoppingRounds -> 0)
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
@@ -201,7 +202,8 @@ class XGBoostDFSuite extends FunSuite with PerTest {
|
||||
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
|
||||
val testRDD = sc.parallelize(Classification.test.map(_.features))
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "baseMarginCol" -> "margin")
|
||||
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
|
||||
"testTrainSplit" -> 0.5)
|
||||
|
||||
def trainPredict(df: Dataset[_]): Array[Float] = {
|
||||
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
|
||||
|
||||
Reference in New Issue
Block a user