[jvm-packages] add configuration flag to control whether to cache transformed training set (#4268)

* control whether to cache data

* uncache the training set when training completes
Author: Nan Zhu, 2019-03-18 10:13:28 +08:00 (committed by GitHub)
parent 29a1356669
commit 359ed9c5bc
3 changed files with 113 additions and 42 deletions
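Usage note: the flag is read as "cacheTrainingSet" from the parameter map and defaults to false (see the LearningTaskParams change below). A minimal sketch of opting in from user code, assuming the usual xgboost4j-spark classifier entry point; the labeled DataFrame `trainingDF` and the parameter values are illustrative:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // "cacheTrainingSet" -> true persists the repartitioned training RDD
    // (StorageLevel.MEMORY_AND_DISK) for the whole training run and
    // unpersists it in a finally block when training ends.
    val xgbParams = Map(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "num_workers" -> 2,
      "cacheTrainingSet" -> true)
    val model = new XGBoostClassifier(xgbParams).fit(trainingDF)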

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.storage.StorageLevel
 
 /**
@@ -305,9 +306,8 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedData = repartitionForTraining(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedData.mapPartitions(labeledPoints => {
+      trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(params,
           removeMissingValues(labeledPoints, missing),
           getCacheDirName(useExternalMemory))
@@ -315,7 +315,7 @@ object XGBoost extends Serializable {
           obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
+      coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
@@ -328,7 +328,7 @@ object XGBoost extends Serializable {
   }
 
   private def trainForRanking(
-      trainingData: RDD[XGBLabeledPoint],
+      trainingData: RDD[Array[XGBLabeledPoint]],
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       checkpointRound: Int,
@@ -336,16 +336,15 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedTrainingSet.mapPartitions(labeledPointGroups => {
+      trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(params,
           removeMissingValuesWithGroup(labeledPointGroups, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
+      coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
         labeledPointGroupSets => {
           val watches = Watches.buildWatchesWithGroup(
             labeledPointGroupSets.map {
@@ -358,6 +357,25 @@ object XGBoost extends Serializable {
     }
   }
 
+  private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = {
+    if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input
+  }
+
+  private def composeInputData(
+      trainingData: RDD[XGBLabeledPoint],
+      ifCacheDataBoolean: Boolean,
+      hasGroup: Boolean,
+      nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
+    if (hasGroup) {
+      val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+      Left(cacheData(ifCacheDataBoolean, repartitionedData).
+        asInstanceOf[RDD[Array[XGBLabeledPoint]]])
+    } else {
+      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
+      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+    }
+  }
+
   /**
    * @return A tuple of the booster and the metrics used to build training summary
    */
@@ -375,43 +393,63 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
     checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
+    val transformedTrainingData = composeInputData(trainingData,
+      params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
     var prevBooster = checkpointManager.loadCheckpointAsBooster
-    // Train for every ${savingRound} rounds and save the partially completed booster
-    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
-      checkpointRound: Int =>
-        val tracker = startTracker(nWorkers, trackerConf)
-        try {
-          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
-          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-          val rabitEnv = tracker.getWorkerEnvs
-          val boostersAndMetrics = if (hasGroup) {
-            trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          } else {
-            trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          }
-          val sparkJobThread = new Thread() {
-            override def run() {
-              // force the job
-              boostersAndMetrics.foreachPartition(() => _)
-            }
-          }
-          sparkJobThread.setUncaughtExceptionHandler(tracker)
-          sparkJobThread.start()
-          val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
-          logger.info(s"Rabit returns with exit code $trackerReturnVal")
-          val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
-            sparkJobThread)
-          if (checkpointRound < round) {
-            prevBooster = booster
-            checkpointManager.updateCheckpoint(prevBooster)
-          }
-          (booster, metrics)
-        } finally {
-          tracker.stop()
-        }
-    }.last
+    try {
+      // Train for every ${savingRound} rounds and save the partially completed booster
+      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+        checkpointRound: Int =>
+          val tracker = startTracker(nWorkers, trackerConf)
+          try {
+            val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+            val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
+              nWorkers)
+            val rabitEnv = tracker.getWorkerEnvs
+            val boostersAndMetrics = if (hasGroup) {
+              trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            } else {
+              trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            }
+            val sparkJobThread = new Thread() {
+              override def run() {
+                // force the job
+                boostersAndMetrics.foreachPartition(() => _)
+              }
+            }
+            sparkJobThread.setUncaughtExceptionHandler(tracker)
+            sparkJobThread.start()
+            val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+            logger.info(s"Rabit returns with exit code $trackerReturnVal")
+            val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
+              boostersAndMetrics, sparkJobThread)
+            if (checkpointRound < round) {
+              prevBooster = booster
+              checkpointManager.updateCheckpoint(prevBooster)
+            }
+            (booster, metrics)
+          } finally {
+            tracker.stop()
+          }
+      }.last
+    } finally {
+      uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
+        transformedTrainingData)
+    }
+  }
+
+  private def uncacheTrainingData(
+      cacheTrainingSet: Boolean,
+      transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = {
+    if (cacheTrainingSet) {
+      if (transformedTrainingData.isLeft) {
+        transformedTrainingData.left.get.unpersist()
+      } else {
+        transformedTrainingData.right.get.unpersist()
+      }
+    }
   }
 
   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
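The control flow above amounts to a persist-in-try / unpersist-in-finally pattern, so the cached RDD is released even when training fails partway through. A standalone sketch of the same idiom (the generic helper name `withOptionalCache` is hypothetical, not part of the commit):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Persist only when the caller opts in, and always release the cache,
    // mirroring cacheData/uncacheTrainingData in the diff above.
    def withOptionalCache[T, R](rdd: RDD[T], cache: Boolean)(body: RDD[T] => R): R = {
      val input = if (cache) rdd.persist(StorageLevel.MEMORY_AND_DISK) else rdd
      try body(input)
      finally if (cache) input.unpersist()
    }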

File: jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala

@@ -76,6 +76,12 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getTrainTestRatio: Double = $(trainTestRatio)
 
+  /**
+   * whether caching training data
+   */
+  final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
+    "whether caching training data")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.
@@ -95,7 +101,7 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
   setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
+    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
 }
 
 private[spark] object LearningTaskParams {
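For readers unfamiliar with the Spark ML `Params` machinery used here, a self-contained sketch of the same idiom (the trait name `CacheParams` is made up for illustration; the param name, doc string, and default match the diff):

    import org.apache.spark.ml.param.{BooleanParam, Params}

    trait CacheParams extends Params {
      // Register a boolean parameter; setDefault makes it optional for callers.
      final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
        "whether caching training data")

      final def getCacheTrainingSet: Boolean = $(cacheTrainingSet)

      setDefault(cacheTrainingSet -> false)
    }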

File: jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala

@@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(error(nextModel._booster) < 0.1)
   }
 
+  test("training with checkpoint boosters with cached training dataset") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    // Check only one model is kept after training
+    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "8.model")
+    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+
+    // Train next model based on prev model
+    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+    assert(error(tmpModel) > error(prevModel._booster))
+    assert(error(prevModel._booster) > error(nextModel._booster))
+    assert(error(nextModel._booster) < 0.1)
+  }
+
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {