[jvm-packages] add configuration flag to control whether to cache transformed training set (#4268)

* control whether to cache data

* uncache the training set after training completes
Author: Nan Zhu, 2019-03-18 10:13:28 +08:00, committed by GitHub
parent 29a1356669
commit 359ed9c5bc
3 changed files with 113 additions and 42 deletions
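
For orientation before the diff: the new flag is consumed through the ordinary XGBoost param map. A minimal, hypothetical caller-side sketch (the DataFrame name trainingDF and the surrounding parameter values are illustrative only, not part of this commit):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Hypothetical usage sketch: opt in to caching the transformed
// (repartitioned) training set. "trainingDF" is an assumed DataFrame.
val paramMap = Map(
  "eta" -> 0.1,
  "max_depth" -> 6,
  "objective" -> "binary:logistic",
  "num_round" -> 10,
  "num_workers" -> 2,
  "cacheTrainingSet" -> true) // flag introduced by this commit; defaults to false
val model = new XGBoostClassifier(paramMap).fit(trainingDF)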

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.storage.StorageLevel

 /**
@@ -305,9 +306,8 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedData = repartitionForTraining(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedData.mapPartitions(labeledPoints => {
+      trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(params,
           removeMissingValues(labeledPoints, missing),
           getCacheDirName(useExternalMemory))
@@ -315,7 +315,7 @@ object XGBoost extends Serializable {
           obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
+      coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
@@ -328,7 +328,7 @@ object XGBoost extends Serializable {
   }

   private def trainForRanking(
-      trainingData: RDD[XGBLabeledPoint],
+      trainingData: RDD[Array[XGBLabeledPoint]],
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       checkpointRound: Int,
@@ -336,16 +336,15 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedTrainingSet.mapPartitions(labeledPointGroups => {
+      trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(params,
           removeMissingValuesWithGroup(labeledPointGroups, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
+      coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
         labeledPointGroupSets => {
           val watches = Watches.buildWatchesWithGroup(
             labeledPointGroupSets.map {
@@ -358,6 +357,25 @@ object XGBoost extends Serializable {
     }
   }

+  private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = {
+    if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input
+  }
+
+  private def composeInputData(
+      trainingData: RDD[XGBLabeledPoint],
+      ifCacheDataBoolean: Boolean,
+      hasGroup: Boolean,
+      nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
+    if (hasGroup) {
+      val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+      Left(cacheData(ifCacheDataBoolean, repartitionedData).
+        asInstanceOf[RDD[Array[XGBLabeledPoint]]])
+    } else {
+      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
+      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+    }
+  }
+
   /**
    * @return A tuple of the booster and the metrics used to build training summary
    */
@@ -375,43 +393,63 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
     checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
+    val transformedTrainingData = composeInputData(trainingData,
+      params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
     var prevBooster = checkpointManager.loadCheckpointAsBooster
-    // Train for every ${savingRound} rounds and save the partially completed booster
-    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
-      checkpointRound: Int =>
-        val tracker = startTracker(nWorkers, trackerConf)
-        try {
-          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
-          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-          val rabitEnv = tracker.getWorkerEnvs
-          val boostersAndMetrics = if (hasGroup) {
-            trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          } else {
-            trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          }
-          val sparkJobThread = new Thread() {
-            override def run() {
-              // force the job
-              boostersAndMetrics.foreachPartition(() => _)
-            }
-          }
+    try {
+      // Train for every ${savingRound} rounds and save the partially completed booster
+      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+        checkpointRound: Int =>
+          val tracker = startTracker(nWorkers, trackerConf)
+          try {
+            val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+            val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
+              nWorkers)
+            val rabitEnv = tracker.getWorkerEnvs
+            val boostersAndMetrics = if (hasGroup) {
+              trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            } else {
+              trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            }
+            val sparkJobThread = new Thread() {
+              override def run() {
+                // force the job
+                boostersAndMetrics.foreachPartition(() => _)
+              }
+            }
+            sparkJobThread.setUncaughtExceptionHandler(tracker)
+            sparkJobThread.start()
+            val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+            logger.info(s"Rabit returns with exit code $trackerReturnVal")
+            val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
+              boostersAndMetrics, sparkJobThread)
+            if (checkpointRound < round) {
+              prevBooster = booster
+              checkpointManager.updateCheckpoint(prevBooster)
+            }
+            (booster, metrics)
+          } finally {
+            tracker.stop()
+          }
-          sparkJobThread.setUncaughtExceptionHandler(tracker)
-          sparkJobThread.start()
-          val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
-          logger.info(s"Rabit returns with exit code $trackerReturnVal")
-          val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
-            sparkJobThread)
-          if (checkpointRound < round) {
-            prevBooster = booster
-            checkpointManager.updateCheckpoint(prevBooster)
-          }
-          (booster, metrics)
-        } finally {
-          tracker.stop()
-        }
-    }.last
+      }.last
+    } finally {
+      uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
+        transformedTrainingData)
+    }
   }
+
+  private def uncacheTrainingData(
+      cacheTrainingSet: Boolean,
+      transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = {
+    if (cacheTrainingSet) {
+      if (transformedTrainingData.isLeft) {
+        transformedTrainingData.left.get.unpersist()
+      } else {
+        transformedTrainingData.right.get.unpersist()
+      }
+    }
+  }

   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
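
The XGBoost.scala changes above boil down to a cache-then-always-uncache pattern: repartition once up front, optionally persist the result at MEMORY_AND_DISK, train inside a try, and unpersist in the matching finally so cached blocks are released even when training throws. A minimal self-contained sketch of that pattern, with illustrative names not taken from the diff:

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Sketch only: persist the derived RDD when requested, and always
// release it, even if the training body fails.
def withOptionalCache[T, R](input: RDD[T], cache: Boolean)(body: RDD[T] => R): R = {
  val data = if (cache) input.persist(StorageLevel.MEMORY_AND_DISK) else input
  try {
    body(data)
  } finally {
    if (cache) data.unpersist()
  }
}

Each checkpoint round launches a separate Spark job over the same transformed RDD, so persisting it once avoids re-running the repartition (and the rest of the input lineage) on every round.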

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala

@@ -76,6 +76,12 @@ private[spark] trait LearningTaskParams extends Params {
   final def getTrainTestRatio: Double = $(trainTestRatio)

+  /**
+   * whether caching training data
+   */
+  final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
+    "whether caching training data")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.
@@ -95,7 +101,7 @@ private[spark] trait LearningTaskParams extends Params {
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)

   setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
+    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
 }

 private[spark] object LearningTaskParams {
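
For readers less familiar with the Spark ML Params machinery used in this hunk: BooleanParam registers a named, documented parameter, $(...) reads its current value, and setDefault supplies the fallback used when the caller sets nothing. A standalone sketch of the same idiom (the trait and getter names are illustrative, not part of the commit):

import org.apache.spark.ml.param.{BooleanParam, Params}

// Illustrative trait mirroring the cacheTrainingSet declaration above.
trait CachingParams extends Params {
  final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
    "whether caching training data")

  final def getCacheTrainingSet: Boolean = $(cacheTrainingSet)

  setDefault(cacheTrainingSet -> false)
}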

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala

@@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(error(nextModel._booster) < 0.1)
   }

+  test("training with checkpoint boosters with cached training dataset") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    // Check only one model is kept after training
+    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "8.model")
+    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+
+    // Train next model based on prev model
+    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+    assert(error(tmpModel) > error(prevModel._booster))
+    assert(error(prevModel._booster) > error(nextModel._booster))
+    assert(error(nextModel._booster) < 0.1)
+  }
+
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {