From 359ed9c5bc8b83f6b5d6c10e130d07d4886c9fec Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Mon, 18 Mar 2019 10:13:28 +0800 Subject: [PATCH] [jvm-packages] add configuration flag to control whether to cache transformed training set (#4268) * control whether to cache data * uncache --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 120 ++++++++++++------ .../spark/params/LearningTaskParams.scala | 8 +- .../scala/spark/XGBoostGeneralSuite.scala | 27 ++++ 3 files changed, 113 insertions(+), 42 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 709601816..43ad1f1bd 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory import org.apache.spark.rdd.RDD import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext} import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.storage.StorageLevel /** @@ -305,9 +306,8 @@ object XGBoost extends Serializable { evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) = parameterFetchAndValidation(params, trainingData.sparkContext) - val partitionedData = repartitionForTraining(trainingData, nWorkers) if (evalSetsMap.isEmpty) { - partitionedData.mapPartitions(labeledPoints => { + trainingData.mapPartitions(labeledPoints => { val watches = Watches.buildWatches(params, removeMissingValues(labeledPoints, missing), getCacheDirName(useExternalMemory)) @@ -315,7 +315,7 @@ object XGBoost extends Serializable { obj, eval, prevBooster) }).cache() } else { - coPartitionNoGroupSets(partitionedData, evalSetsMap, nWorkers).mapPartitions { + coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions { nameAndLabeledPointSets => val watches = Watches.buildWatches( nameAndLabeledPointSets.map { @@ -328,7 +328,7 @@ object XGBoost extends Serializable { } private def trainForRanking( - trainingData: RDD[XGBLabeledPoint], + trainingData: RDD[Array[XGBLabeledPoint]], params: Map[String, Any], rabitEnv: java.util.Map[String, String], checkpointRound: Int, @@ -336,16 +336,15 @@ object XGBoost extends Serializable { evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) = parameterFetchAndValidation(params, trainingData.sparkContext) - val partitionedTrainingSet = repartitionForTrainingGroup(trainingData, nWorkers) if (evalSetsMap.isEmpty) { - partitionedTrainingSet.mapPartitions(labeledPointGroups => { + trainingData.mapPartitions(labeledPointGroups => { val watches = Watches.buildWatchesWithGroup(params, removeMissingValuesWithGroup(labeledPointGroups, missing), getCacheDirName(useExternalMemory)) buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster) }).cache() } else { - coPartitionGroupSets(partitionedTrainingSet, evalSetsMap, nWorkers).mapPartitions( + coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions( labeledPointGroupSets => { val watches = Watches.buildWatchesWithGroup( labeledPointGroupSets.map { @@ -358,6 +357,25 @@ object XGBoost extends Serializable { } } + private def cacheData(ifCacheDataBoolean: Boolean, input: RDD[_]): RDD[_] = { + if (ifCacheDataBoolean) input.persist(StorageLevel.MEMORY_AND_DISK) else input + } + + private def composeInputData( + trainingData: RDD[XGBLabeledPoint], + ifCacheDataBoolean: Boolean, + hasGroup: Boolean, + nWorkers: Int): Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]] = { + if (hasGroup) { + val repartitionedData = repartitionForTrainingGroup(trainingData, nWorkers) + Left(cacheData(ifCacheDataBoolean, repartitionedData). + asInstanceOf[RDD[Array[XGBLabeledPoint]]]) + } else { + val repartitionedData = repartitionForTraining(trainingData, nWorkers) + Right(cacheData(ifCacheDataBoolean, repartitionedData).asInstanceOf[RDD[XGBLabeledPoint]]) + } + } + /** * @return A tuple of the booster and the metrics used to build training summary */ @@ -375,43 +393,63 @@ object XGBoost extends Serializable { val sc = trainingData.sparkContext val checkpointManager = new CheckpointManager(sc, checkpointPath) checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int]) + val transformedTrainingData = composeInputData(trainingData, + params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers) var prevBooster = checkpointManager.loadCheckpointAsBooster - // Train for every ${savingRound} rounds and save the partially completed booster - checkpointManager.getCheckpointRounds(checkpointInterval, round).map { - checkpointRound: Int => - val tracker = startTracker(nWorkers, trackerConf) - try { - val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc) - val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers) - val rabitEnv = tracker.getWorkerEnvs - val boostersAndMetrics = if (hasGroup) { - trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound, - prevBooster, evalSetsMap) - } else { - trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound, - prevBooster, evalSetsMap) - } - val sparkJobThread = new Thread() { - override def run() { - // force the job - boostersAndMetrics.foreachPartition(() => _) + try { + // Train for every ${savingRound} rounds and save the partially completed booster + checkpointManager.getCheckpointRounds(checkpointInterval, round).map { + checkpointRound: Int => + val tracker = startTracker(nWorkers, trackerConf) + try { + val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc) + val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, + nWorkers) + val rabitEnv = tracker.getWorkerEnvs + val boostersAndMetrics = if (hasGroup) { + trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv, + checkpointRound, prevBooster, evalSetsMap) + } else { + trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv, + checkpointRound, prevBooster, evalSetsMap) } + val sparkJobThread = new Thread() { + override def run() { + // force the job + boostersAndMetrics.foreachPartition(() => _) + } + } + sparkJobThread.setUncaughtExceptionHandler(tracker) + sparkJobThread.start() + val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L)) + logger.info(s"Rabit returns with exit code $trackerReturnVal") + val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, + boostersAndMetrics, sparkJobThread) + if (checkpointRound < round) { + prevBooster = booster + checkpointManager.updateCheckpoint(prevBooster) + } + (booster, metrics) + } finally { + tracker.stop() } - sparkJobThread.setUncaughtExceptionHandler(tracker) - sparkJobThread.start() - val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L)) - logger.info(s"Rabit returns with exit code $trackerReturnVal") - val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics, - sparkJobThread) - if (checkpointRound < round) { - prevBooster = booster - checkpointManager.updateCheckpoint(prevBooster) - } - (booster, metrics) - } finally { - tracker.stop() - } - }.last + }.last + } finally { + uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], + transformedTrainingData) + } + } + + private def uncacheTrainingData( + cacheTrainingSet: Boolean, + transformedTrainingData: Either[RDD[Array[XGBLabeledPoint]], RDD[XGBLabeledPoint]]): Unit = { + if (cacheTrainingSet) { + if (transformedTrainingData.isLeft) { + transformedTrainingData.left.get.unpersist() + } else { + transformedTrainingData.right.get.unpersist() + } + } } private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 5818a7af3..a621305b0 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -76,6 +76,12 @@ private[spark] trait LearningTaskParams extends Params { final def getTrainTestRatio: Double = $(trainTestRatio) + /** + * whether caching training data + */ + final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet", + "whether caching training data") + /** * If non-zero, the training will be stopped after a specified number * of consecutive increases in any evaluation metric. @@ -95,7 +101,7 @@ private[spark] trait LearningTaskParams extends Params { final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics) setDefault(objective -> "reg:squarederror", baseScore -> 0.5, - trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0) + trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0, cacheTrainingSet -> false) } private[spark] object LearningTaskParams { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index d5ecb3eca..09b5a8883 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -286,6 +286,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(error(nextModel._booster) < 0.1) } + test("training with checkpoint boosters with cached training dataset") { + val eval = new EvalError() + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + + val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString + val paramMap = Map("eta" -> "1", "max_depth" -> 2, + "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath, + "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true) + + val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training) + def error(model: Booster): Float = eval.eval( + model.predict(testDM, outPutMargin = true), testDM) + + // Check only one model is kept after training + val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) + assert(files.length == 1) + assert(files.head.getPath.getName == "8.model") + val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") + + // Train next model based on prev model + val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) + assert(error(tmpModel) > error(prevModel._booster)) + assert(error(prevModel._booster) > error(nextModel._booster)) + assert(error(nextModel._booster) < 0.1) + } + test("repartitionForTrainingGroup with group data") { // test different splits to cover the corner cases. for (split <- 1 to 20) {