From 7b5cbcc8468448245a1c2ab698d07cb05be75e94 Mon Sep 17 00:00:00 2001
From: Nan Zhu
Date: Wed, 14 Aug 2019 10:57:47 -0700
Subject: [PATCH] [jvm-packages] cleaning checkpoint file after a successful
 training (#4754)

* cleaning checkpoint file after a successful training

* address comments
---
 .../scala/spark/CheckpointManager.scala       | 22 +++++++-
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 20 ++++---
 .../spark/params/LearningTaskParams.scala     |  7 +++
 .../scala/spark/CheckpointManagerSuite.scala  | 47 ++++++++++++++++
 .../scala/spark/XGBoostGeneralSuite.scala     | 54 -------------------
 5 files changed, 88 insertions(+), 62 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
index af872dc5d..e4705f181 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
@@ -53,6 +53,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
     }
   }
 
+  def cleanPath(): Unit = {
+    if (checkpointPath != "") {
+      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
+    }
+  }
+
   /**
    * Load existing checkpoint with the highest version as a Booster object
    *
@@ -127,7 +133,12 @@
 
 object CheckpointManager {
 
-  private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
+  case class CheckpointParam(
+      checkpointPath: String,
+      checkpointInterval: Int,
+      skipCleanCheckpoint: Boolean)
+
+  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
     val checkpointPath: String = params.get("checkpoint_path") match {
       case None => ""
       case Some(path: String) => path
@@ -141,6 +152,13 @@ object CheckpointManager {
       case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
         " an instance of Int.")
     }
-    (checkpointPath, checkpointInterval)
+
+    val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean.")
+    }
+    CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 2bfc97670..d14b84ddc 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -331,9 +331,11 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
         " an instance of Long.")
     }
-    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
+    val checkpointParam =
+      CheckpointManager.extractParams(params)
     (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval)
+      checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
+      checkpointParam.skipCleanCheckpoint)
   }
 
   private def trainForNonRanking(
@@ -343,7 +345,7 @@ object XGBoost extends Serializable {
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
@@ -373,7 +375,7 @@ object XGBoost extends Serializable {
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
@@ -427,7 +429,8 @@ object XGBoost extends Serializable {
       (Booster, Map[String, Array[Float]]) = {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
+      checkpointPath, checkpointInterval, skipCleanCheckpoint) =
+      parameterFetchAndValidation(params,
       trainingData.sparkContext)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -437,7 +440,7 @@ object XGBoost extends Serializable {
     var prevBooster = checkpointManager.loadCheckpointAsBooster
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
         checkpointRound: Int =>
           val tracker = startTracker(nWorkers, trackerConf)
           try {
@@ -473,6 +476,11 @@ object XGBoost extends Serializable {
           tracker.stop()
         }
       }.last
+      // we should delete the checkpoint directory after a successful training
+      if (!skipCleanCheckpoint) {
+        checkpointManager.cleanPath()
+      }
+      producedBooster
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index 414962d36..1512c85d0 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -80,6 +80,13 @@ private[spark] trait LearningTaskParams extends Params {
   final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
     "whether caching training data")
 
+  /**
+   * Whether to skip cleaning the checkpoint after a successful training. The checkpoint is
+   * always cleaned by default; this parameter exists mainly for testing.
+   */
+  final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
+    "whether to skip cleaning the checkpoint data")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
index c91343d06..ddeb48241 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
 
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import org.scalatest.FunSuite
 import org.apache.hadoop.fs.{FileSystem, Path}
 
@@ -67,4 +68,50 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
 
     assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
   }
+
+  private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
+    val skipCleanCheckpointMap =
+      if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
+      skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    if (skipCleanCheckpoint) {
+      // Check only one model is kept after training
+      val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+      assert(files.length == 1)
+      assert(files.head.getPath.getName == "8.model")
+      val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+      // Train next model based on prev model
+      val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(prevModel._booster) > error(nextModel._booster))
+      assert(error(nextModel._booster) < 0.1)
+    } else {
+      assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
+    }
+  }
+
+  test("training with checkpoint boosters") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
+  }
+
+  test("training with checkpoint boosters with cached training dataset") {
+    trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
+  }
+
+  test("the checkpoint file should be cleaned after a successful training") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index fb29aecfc..f3492b2e3 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -179,60 +179,6 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     assert(x < 0.1)
   }
 
-  test("training with checkpoint boosters") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
-  test("training with checkpoint boosters with cached training dataset") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {
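
Note (not part of the patch): a minimal usage sketch of the behavior this change introduces. The parameter keys "checkpoint_path", "checkpoint_interval", and "skip_clean_checkpoint" are the ones parsed by CheckpointManager.extractParams above; the checkpoint path and the `training` DataFrame (with the default "features"/"label" columns) are hypothetical placeholders here, assumed to be prepared by the caller.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// `training` is assumed to be an existing DataFrame of labeled feature vectors.
val xgbParams = Map(
  "eta" -> "1",
  "max_depth" -> 2,
  "objective" -> "binary:logistic",
  "num_workers" -> 2,
  "num_round" -> 10,
  // a partial booster is saved here every `checkpoint_interval` rounds
  "checkpoint_path" -> "/tmp/xgb-checkpoint",  // hypothetical path
  "checkpoint_interval" -> 2,
  // false is the default: after a successful training the checkpoint
  // directory is now deleted; set true (mainly in tests) to keep it
  "skip_clean_checkpoint" -> false
)
val model = new XGBoostClassifier(xgbParams).fit(training)

With this patch applied, a job that fails mid-training still leaves its latest checkpoint in place for restart, while a successful run no longer leaves stale model files behind.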