[jvm-packages] add configuration flag to control whether to cache transformed training set (#4268)
* control whether to cache data * uncache
This commit is contained in:
parent
29a1356669
commit
359ed9c5bc
@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
|
||||
/**
|
||||
@ -305,9 +306,8 @@ object XGBoost extends Serializable {
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
partitionedData.mapPartitions(labeledPoints => {
|
||||
trainingData.mapPartitions(labeledPoints => {
|
||||
val watches = Watches.buildWatches(params,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
@ -315,7 +315,7 @@ object XGBoost extends Serializable {
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
|
||||
coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
|
||||
nameAndLabeledPointSets =>
|
||||
val watches = Watches.buildWatches(
|
||||
nameAndLabeledPointSets.map {
|
||||
@ -328,7 +328,7 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
|
||||
private def trainForRanking(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
trainingData: RDD[Array[XGBLabeledPoint]],
|
||||
params: Map[String, Any],
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
checkpointRound: Int,
|
||||
@ -336,16 +336,15 @@ object XGBoost extends Serializable {
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
partitionedTrainingSet.mapPartitions(labeledPointGroups => {
|
||||
trainingData.mapPartitions(labeledPointGroups => {
|
||||
val watches = Watches.buildWatchesWithGroup(params,
|
||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
|
||||
coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
|
||||
labeledPointGroupSets => {
|
||||
val watches = Watches.buildWatchesWithGroup(
|
||||
labeledPointGroupSets.map {
|
||||
@ -358,6 +357,25 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = {
|
||||
if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input
|
||||
}
|
||||
|
||||
private def composeInputData(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
ifCacheDataBoolean: Boolean,
|
||||
hasGroup: Boolean,
|
||||
nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
|
||||
if (hasGroup) {
|
||||
val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
Left(cacheData(ifCacheDataBoolean, repartitionedData).
|
||||
asInstanceOf[RDD[Array[XGBLabeledPoint]]])
|
||||
} else {
|
||||
val repartitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||
Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A tuple of the booster and the metrics used to build training summary
|
||||
*/
|
||||
@ -375,43 +393,63 @@ object XGBoost extends Serializable {
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
||||
checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
|
||||
val transformedTrainingData = composeInputData(trainingData,
|
||||
params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
|
||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster, evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boostersAndMetrics.foreachPartition(() => _)
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
|
||||
nWorkers)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
|
||||
checkpointRound, prevBooster, evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
|
||||
checkpointRound, prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boostersAndMetrics.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
|
||||
boostersAndMetrics, sparkJobThread)
|
||||
if (checkpointRound < round) {
|
||||
prevBooster = booster
|
||||
checkpointManager.updateCheckpoint(prevBooster)
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
|
||||
sparkJobThread)
|
||||
if (checkpointRound < round) {
|
||||
prevBooster = booster
|
||||
checkpointManager.updateCheckpoint(prevBooster)
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
}.last
|
||||
}.last
|
||||
} finally {
|
||||
uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
|
||||
transformedTrainingData)
|
||||
}
|
||||
}
|
||||
|
||||
private def uncacheTrainingData(
|
||||
cacheTrainingSet: Boolean,
|
||||
transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = {
|
||||
if (cacheTrainingSet) {
|
||||
if (transformedTrainingData.isLeft) {
|
||||
transformedTrainingData.left.get.unpersist()
|
||||
} else {
|
||||
transformedTrainingData.right.get.unpersist()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
|
||||
|
||||
@ -76,6 +76,12 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
|
||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||
|
||||
/**
|
||||
* whether caching training data
|
||||
*/
|
||||
final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
|
||||
"whether caching training data")
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
@ -95,7 +101,7 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
|
||||
|
||||
setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
|
||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
|
||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters with cached training dataset") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
}
|
||||
|
||||
test("repartitionForTrainingGroup with group data") {
|
||||
// test different splits to cover the corner cases.
|
||||
for (split <- 1 to 20) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user