[jvm-packages] add configuration flag to control whether to cache transformed training set (#4268)

* control whether to cache data
* uncache

parent 29a1356669
commit 359ed9c5bc
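For users of the Spark package, the new behavior is opt-in through the `cacheTrainingSet` parameter (default `false`). A minimal usage sketch, assuming a prepared training DataFrame named `training` and otherwise arbitrary parameters; only the "cacheTrainingSet" key is introduced by this patch:

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // `training` is a hypothetical DataFrame of labeled rows. With
    // "cacheTrainingSet" -> true, the repartitioned training RDD is persisted
    // at MEMORY_AND_DISK for the whole run and unpersisted when training ends.
    val classifier = new XGBoostClassifier(Map(
      "eta" -> 0.1,
      "objective" -> "binary:logistic",
      "num_round" -> 100,
      "num_workers" -> 4,
      "cacheTrainingSet" -> true
    ))
    val model = classifier.fit(training)

Caching pays off when the input pipeline feeding XGBoost is expensive, since each checkpoint interval re-traverses the transformed training RDD.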
@@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.storage.StorageLevel
 
 
 /**
@@ -305,9 +306,8 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedData = repartitionForTraining(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedData.mapPartitions(labeledPoints => {
+      trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(params,
           removeMissingValues(labeledPoints, missing),
           getCacheDirName(useExternalMemory))
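With this change, `trainForNonRanking` no longer repartitions its input: repartitioning, and optionally caching, now happens once in the caller through the new `composeInputData` helper (added below), so the method can assume data already spread across `nWorkers` partitions. The ranking path in `trainForRanking` receives the same treatment in the hunks that follow.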
@@ -315,7 +315,7 @@ object XGBoost extends Serializable {
           obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions {
+      coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
@@ -328,7 +328,7 @@ object XGBoost extends Serializable {
   }
 
   private def trainForRanking(
-      trainingData: RDD[XGBLabeledPoint],
+      trainingData: RDD[Array[XGBLabeledPoint]],
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       checkpointRound: Int,
@@ -336,16 +336,15 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
-    val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers)
     if (evalSetsMap.isEmpty) {
-      partitionedTrainingSet.mapPartitions(labeledPointGroups => {
+      trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(params,
           removeMissingValuesWithGroup(labeledPointGroups, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
       }).cache()
     } else {
-      coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions(
+      coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
         labeledPointGroupSets => {
           val watches = Watches.buildWatchesWithGroup(
             labeledPointGroupSets.map {
@@ -358,6 +357,25 @@ object XGBoost extends Serializable {
       }
   }
 
+  private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = {
+    if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input
+  }
+
+  private def composeInputData(
+    trainingData: RDD[XGBLabeledPoint],
+    ifCacheDataBoolean: Boolean,
+    hasGroup: Boolean,
+    nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = {
+    if (hasGroup) {
+      val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
+      Left(cacheData(ifCacheDataBoolean, repartitionedData).
+        asInstanceOf[RDD[Array[XGBLabeledPoint]]])
+    } else {
+      val repartitionedData = repartitionForTraining(trainingData, nWorkers)
+      Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]])
+    }
+  }
+
   /**
    * @return A tuple of the booster and the metrics used to build training summary
    */
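`composeInputData` folds the two input shapes into a single value: grouped (ranking) input becomes `RDD[Array[XGBLabeledPoint]]`, flat input stays `RDD[XGBLabeledPoint]`, and the `Either` lets one variable carry both through the training entry point. The caching half relies on `persist` being lazy; a standalone sketch of the same pattern under illustrative names (not from the patch):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // persist() only marks the RDD for caching; nothing is materialized until
    // the first action runs, and the disabled branch returns the input
    // untouched, so opting out costs nothing.
    def cacheIf[T](enabled: Boolean, input: RDD[T]): RDD[T] =
      if (enabled) input.persist(StorageLevel.MEMORY_AND_DISK) else input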
@@ -375,43 +393,63 @@ object XGBoost extends Serializable {
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
     checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
+    val transformedTrainingData = composeInputData(trainingData,
+      params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
     var prevBooster = checkpointManager.loadCheckpointAsBooster
-    // Train for every ${savingRound} rounds and save the partially completed booster
-    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
-      checkpointRound: Int =>
-        val tracker = startTracker(nWorkers, trackerConf)
-        try {
-          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
-          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-          val rabitEnv = tracker.getWorkerEnvs
-          val boostersAndMetrics = if (hasGroup) {
-            trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          } else {
-            trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
-              prevBooster, evalSetsMap)
-          }
-          val sparkJobThread = new Thread() {
-            override def run() {
-              // force the job
-              boostersAndMetrics.foreachPartition(() => _)
-            }
-          }
-          sparkJobThread.setUncaughtExceptionHandler(tracker)
-          sparkJobThread.start()
-          val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
-          logger.info(s"Rabit returns with exit code $trackerReturnVal")
-          val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
-            sparkJobThread)
-          if (checkpointRound < round) {
-            prevBooster = booster
-            checkpointManager.updateCheckpoint(prevBooster)
-          }
-          (booster, metrics)
-        } finally {
-          tracker.stop()
-        }
-    }.last
+    try {
+      // Train for every ${savingRound} rounds and save the partially completed booster
+      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+        checkpointRound: Int =>
+          val tracker = startTracker(nWorkers, trackerConf)
+          try {
+            val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+            val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
+              nWorkers)
+            val rabitEnv = tracker.getWorkerEnvs
+            val boostersAndMetrics = if (hasGroup) {
+              trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            } else {
+              trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
+                checkpointRound, prevBooster, evalSetsMap)
+            }
+            val sparkJobThread = new Thread() {
+              override def run() {
+                // force the job
+                boostersAndMetrics.foreachPartition(() => _)
+              }
+            }
+            sparkJobThread.setUncaughtExceptionHandler(tracker)
+            sparkJobThread.start()
+            val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+            logger.info(s"Rabit returns with exit code $trackerReturnVal")
+            val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
+              boostersAndMetrics, sparkJobThread)
+            if (checkpointRound < round) {
+              prevBooster = booster
+              checkpointManager.updateCheckpoint(prevBooster)
+            }
+            (booster, metrics)
+          } finally {
+            tracker.stop()
+          }
+      }.last
+    } finally {
+      uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
+        transformedTrainingData)
+    }
+  }
+
+  private def uncacheTrainingData(
+      cacheTrainingSet: Boolean,
+      transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = {
+    if (cacheTrainingSet) {
+      if (transformedTrainingData.isLeft) {
+        transformedTrainingData.left.get.unpersist()
+      } else {
+        transformedTrainingData.right.get.unpersist()
+      }
+    }
+  }
 
   private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
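The training loop is now wrapped in `try`/`finally` so the cached RDD is released even when a checkpoint round fails. A self-contained sketch of that lifecycle, with toy data standing in for the transformed training set:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.storage.StorageLevel

    object CacheLifecycleSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        // stands in for the repartitioned, optionally cached training set
        val data = spark.sparkContext.parallelize(1 to 1000).persist(StorageLevel.MEMORY_AND_DISK)
        try {
          // stands in for the checkpoint-round loop: repeated actions over the
          // same RDD are what make the cache pay off
          (1 to 3).foreach(_ => data.sum())
        } finally {
          data.unpersist() // deterministic release, mirroring uncacheTrainingData
          spark.stop()
        }
      }
    }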
@@ -76,6 +76,12 @@ private[spark] trait LearningTaskParams extends Params {
 
   final def getTrainTestRatio: Double = $(trainTestRatio)
 
+  /**
+   * whether caching training data
+   */
+  final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
+    "whether caching training data")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.
@@ -95,7 +101,7 @@ private[spark] trait LearningTaskParams extends Params {
   final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
 
   setDefault(objective -> "reg:squarederror", baseScore -> 0.5,
-    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
+    trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
 }
 
 private[spark] object LearningTaskParams {
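The parameter follows the standard Spark ML `Params` pattern. For reference, a self-contained sketch: the declaration and default mirror the hunks above, while the getter is an assumption modeled on `getTrainTestRatio` (the diff itself does not add one):

    import org.apache.spark.ml.param.{BooleanParam, Params}

    trait CachingParams extends Params {
      // same declaration as in the patch: parent, name, doc string
      final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
        "whether caching training data")

      // hypothetical getter in the style of getTrainTestRatio
      final def getCacheTrainingSet: Boolean = $(cacheTrainingSet)

      // registering a default makes $(cacheTrainingSet) always resolvable
      setDefault(cacheTrainingSet -> false)
    }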
@@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(error(nextModel._booster) < 0.1)
   }
 
+  test("training with checkpoint boosters with cached training dataset") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
+
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    // Check only one model is kept after training
+    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "8.model")
+    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+
+    // Train next model based on prev model
+    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+    assert(error(tmpModel) > error(prevModel._booster))
+    assert(error(prevModel._booster) > error(nextModel._booster))
+    assert(error(nextModel._booster) < 0.1)
+  }
+
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {
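A note on the expected checkpoint name: with `checkpoint_interval -> 2` and `num_round -> 5`, the last checkpoint lands on round 4, and the file is named `8.model` apparently because the Rabit version counter advances twice per boosting round (version = 2 × round); `cleanUpHigherVersions` leaves it as the only file. Apart from `"cacheTrainingSet" -> true`, the test mirrors the existing checkpoint test so the cached code path is exercised with unchanged expectations.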