From 3f3f54bcad1a62661ff5dbf276c2f4f0a459c5e4 Mon Sep 17 00:00:00 2001
From: Yun Ni
Date: Tue, 16 Jan 2018 08:16:55 -0800
Subject: [PATCH] [jvm-packages] Update docs and unify the terminology (#3024)

* [jvm-packages] Move cache files to tmp dir and delete on exit

* [jvm-packages] Update docs and unify terminology

* Address CR Comments
---
 .../scala/spark/CheckpointManager.scala        | 54 ++++++++++---------
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala   | 24 ++++-----
 .../scala/spark/params/GeneralParams.scala     | 21 ++++----
 .../scala/spark/CheckpointManagerSuite.scala   | 20 +++----
 .../scala/spark/XGBoostGeneralSuite.scala      |  4 +-
 5 files changed, 65 insertions(+), 58 deletions(-)
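The user-facing change in this patch is the parameter rename: `saving_frequency` becomes `checkpoint_interval`, and the saved artifact is consistently called a checkpoint. A minimal usage sketch, modeled on the updated XGBoostGeneralSuite test at the end of this patch; `trainingRDD`, `numWorkers`, and the HDFS folder are placeholders:

    val paramMap = Map(
      "eta" -> "1", "max_depth" -> 2, "silent" -> "1",
      "objective" -> "binary:logistic",
      "checkpoint_path" -> "hdfs:///tmp/xgb-checkpoints", // placeholder checkpoint folder
      "checkpoint_interval" -> 2)                         // was "saving_frequency"
    // If an earlier run left checkpoints under checkpoint_path, training resumes
    // from the newest one instead of restarting from round 0.
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)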
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
index ae7e296ad..3756c152c 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
@@ -22,12 +22,16 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.SparkContext
 
 /**
- * A class which allows user to save checkpoint boosters every a few rounds. If a previous job
- * fails, the job can restart training from a saved booster instead of from scratch. This class
+ * A class which allows the user to save checkpoints every few rounds. If a previous job fails,
+ * the job can restart training from a saved checkpoint instead of from scratch. This class
  * provides interface and helper methods for the checkpoint functionality.
  *
+ * NOTE: This checkpoint is different from the Rabit checkpoint. The Rabit checkpoint is a
+ * native-level checkpoint stored in executor memory, whereas this checkpoint is stored by the
+ * Spark driver on HDFS every few iterations.
+ *
  * @param sc the sparkContext object
- * @param checkpointPath the hdfs path to store checkpoint boosters
+ * @param checkpointPath the hdfs path to store checkpoints
  */
 private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
 
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -49,11 +53,11 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String
   }
 
   /**
-   * Load existing checkpoint with the highest version.
+   * Load the existing checkpoint with the highest version as a Booster object.
    *
    * @return the booster with the highest version, null if no checkpoints available.
    */
-  private[spark] def loadBooster: Booster = {
+  private[spark] def loadCheckpointAsBooster: Booster = {
     val versions = getExistingVersions
     if (versions.nonEmpty) {
       val version = versions.max
@@ -68,16 +72,16 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String
   }
 
   /**
-   * Clean up all previous models and save a new model
+   * Clean up all previous checkpoints and save a new checkpoint
    *
-   * @param model the xgboost model to save
+   * @param checkpoint the checkpoint to save, as an XGBoostModel
    */
-  private[spark] def updateModel(model: XGBoostModel): Unit = {
+  private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
     val fs = FileSystem.get(sc.hadoopConfiguration)
     val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
-    val fullPath = getPath(model.version)
-    logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
-    model.saveModelAsHadoopFile(fullPath)(sc)
+    val fullPath = getPath(checkpoint.version)
+    logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
+    checkpoint.saveModelAsHadoopFile(fullPath)(sc)
     prevModelPaths.foreach(path => fs.delete(path, true))
   }
 
@@ -95,22 +99,22 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String
   }
 
   /**
-   * Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
-   * total number of rounds for the training. Concretely, the saving rounds start with
-   * prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
-   * number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
-   * returns Seq(round)
+   * Calculate the list of rounds at which to save checkpoints, based on the checkpointInterval
+   * and the total number of rounds for the training. Concretely, the checkpoint rounds start
+   * at prevRounds + checkpointInterval and increase by checkpointInterval at each step until
+   * they reach the total number of rounds. If checkpointInterval is 0, checkpointing is
+   * disabled and the method returns Seq(round)
    *
-   * @param savingFreq the increase on rounds during each step of training
+   * @param checkpointInterval Period (in iterations) between checkpoints.
    * @param round the total number of rounds for the training
    * @return a seq of integers, each represent the index of round to save the checkpoints
    */
-  private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
-    if (checkpointPath.nonEmpty && savingFreq > 0) {
+  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
+    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
       val prevRounds = getExistingVersions.map(_ / 2)
-      val firstSavingRound = (0 +: prevRounds).max + savingFreq
-      (firstSavingRound until round by savingFreq) :+ round
-    } else if (savingFreq <= 0) {
+      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
+      (firstCheckpointRound until round by checkpointInterval) :+ round
+    } else if (checkpointInterval <= 0) {
       Seq(round)
     } else {
       throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
@@ -128,12 +132,12 @@ object CheckpointManager {
         " an instance of String.")
     }
 
-    val savingFreq: Int = params.get("saving_frequency") match {
+    val checkpointInterval: Int = params.get("checkpoint_interval") match {
       case None => 0
       case Some(freq: Int) => freq
-      case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
     }
-    (checkpointPath, savingFreq)
+    (checkpointPath, checkpointInterval)
   }
 }
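To make the scheduling arithmetic concrete, here are the values the updated CheckpointManagerSuite (later in this patch) expects. `checkpointRounds` below is a small standalone replica of getCheckpointRounds for illustration only; the real method derives the previous round from the versions of existing checkpoint files:

    // prevRound: rounds already covered by an existing checkpoint (0 when none).
    def checkpointRounds(prevRound: Int, interval: Int, round: Int): Seq[Int] =
      if (interval > 0) (prevRound + interval until round by interval) :+ round
      else Seq(round)

    assert(checkpointRounds(prevRound = 0, interval = 2, round = 7) == Seq(2, 4, 6, 7))
    assert(checkpointRounds(prevRound = 2, interval = 2, round = 7) == Seq(4, 6, 7)) // resume
    assert(checkpointRounds(prevRound = 0, interval = 0, round = 7) == Seq(7))       // disabled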
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 3d342ff07..736dfd060 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -17,6 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
+import java.nio.file.Files
 
 import scala.collection.mutable
 import scala.util.Random
@@ -120,11 +121,9 @@ object XGBoost extends Serializable {
     }
     val taskId = TaskContext.getPartitionId().toString
     val cacheDirName = if (useExternalMemory) {
-      val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
-      if (!(dir.exists() || dir.mkdirs())) {
-        throw new XGBoostError(s"failed to create cache directory: $dir")
-      }
-      Some(dir.toString)
+      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+      new File(dir.toUri).deleteOnExit()
+      Some(dir.toAbsolutePath.toString)
     } else {
       None
     }
@@ -325,23 +324,24 @@ object XGBoost extends Serializable {
      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
        " an instance of Long.")
    }
-    val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
+    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
    val partitionedData = repartitionForTraining(trainingData, nWorkers)

    val sc = trainingData.sparkContext
    val checkpointManager = new CheckpointManager(sc, checkpointPath)
    checkpointManager.cleanUpHigherVersions(round)
-    var prevBooster = checkpointManager.loadBooster
+    var prevBooster = checkpointManager.loadCheckpointAsBooster
    // Train for every ${savingRound} rounds and save the partially completed booster
-    checkpointManager.getSavingRounds(savingFeq, round).map {
-      savingRound: Int =>
+    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      checkpointRound: Int =>
        val tracker = startTracker(nWorkers, trackerConf)
        try {
          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
          val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
-            tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
+            tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
+            prevBooster)
          val sparkJobThread = new Thread() {
            override def run() {
              // force the job
@@ -359,9 +359,9 @@ object XGBoost extends Serializable {
            model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
              params.getOrElse("num_class", "2").toString.toInt
          }
-          if (savingRound < round) {
+          if (checkpointRound < round) {
            prevBooster = model.booster
-            checkpointManager.updateModel(model)
+            checkpointManager.updateCheckpoint(model)
          }
          model
        } finally {
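Note the two changes to XGBoost.scala above: the external-memory cache now lives in a JVM temp directory that is deleted on JVM exit rather than in a manually created directory that was never cleaned up, and the training loop uses the renamed checkpoint API. The temp-directory pattern in isolation (the prefix is made up):

    import java.io.File
    import java.nio.file.Files

    // A uniquely named directory under java.io.tmpdir; unlike new File(...).mkdirs(),
    // it cannot collide with a directory left behind by a previous task attempt.
    val cacheDir = Files.createTempDirectory("stage-1-cache-0")
    new File(cacheDir.toUri).deleteOnExit() // best-effort cleanup when the JVM exits
    val cacheDirName = Some(cacheDir.toAbsolutePath.toString)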
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 106514f96..90a943b49 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -71,7 +71,7 @@ trait GeneralParams extends Params {
   val missing = new FloatParam(this, "missing", "the value treated as missing")
 
   /**
-   * the interval to check whether total numCores is no smaller than nWorkers. default: 30 minutes
+   * the maximum time to wait when requesting new workers. default: 30 minutes
    */
   val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
     " request new Workers if numCores are insufficient. The timeout will be disabled if this" +
@@ -81,16 +81,19 @@ trait GeneralParams extends Params {
    * The hdfs folder to load and save checkpoint boosters. default: `empty_string`
    */
   val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
-    "save checkpoints. The job will try to load the existing booster as the starting point for " +
-    "training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
+    "save checkpoints. If there are existing checkpoints in checkpoint_path, the job will load " +
+    "the checkpoint with the highest version as the starting point for training. If " +
+    "checkpoint_interval is also set, the job will save a checkpoint every few rounds.")
 
   /**
-   * The frequency to save checkpoint boosters. default: 0
+   * Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). E.g. 10
+   * means that the trained model will be checkpointed every 10 iterations. Note:
+   * `checkpoint_path` must also be set if the checkpoint interval is greater than 0.
    */
-  val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
-    " the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
-    " same setting, it will load the existing booster instead of training from scratch." +
-    " Checkpoint will be disabled if set to 0.")
+  val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
+    "interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
+    "checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
+    " interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
 
   /**
    * Rabit tracker configurations. The parameter must be provided as an instance of the
@@ -128,6 +131,6 @@ trait GeneralParams extends Params {
     useExternalMemory -> false, silent -> 0,
     customObj -> null, customEval -> null, missing -> Float.NaN,
     trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
-    checkpointPath -> "", savingFrequency -> 0
+    checkpointPath -> "", checkpointInterval -> -1
   )
 }
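The new checkpointInterval param follows Spark ML's convention: -1 disables checkpointing and any value >= 1 enables it; everything else is rejected by the validator lambda shown above. A self-contained check of the same predicate:

    val isValidInterval: Int => Boolean = interval => interval == -1 || interval >= 1

    assert(isValidInterval(-1))  // checkpointing disabled
    assert(isValidInterval(10))  // checkpoint every 10 iterations
    assert(!isValidInterval(0))  // rejected; use -1 to disable

Note that the raw-map path in CheckpointManager.extractParams still defaults a missing checkpoint_interval to 0, which getCheckpointRounds also treats as disabled.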
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
index f0c9ba697..4a6feb2e4 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
@@ -45,23 +45,23 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
   test("test update/load models") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateModel(model4)
+    manager.updateCheckpoint(model4)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "4.model")
-    assert(manager.loadBooster.booster.getVersion == 4)
+    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
 
-    manager.updateModel(model8)
+    manager.updateCheckpoint(model8)
     files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "8.model")
-    assert(manager.loadBooster.booster.getVersion == 8)
+    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
   }
 
   test("test cleanUpHigherVersions") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateModel(model8)
+    manager.updateCheckpoint(model8)
     manager.cleanUpHigherVersions(round = 8)
     assert(new File(s"$tmpPath/8.model").exists())
 
@@ -69,12 +69,12 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
     assert(!new File(s"$tmpPath/8.model").exists())
   }
 
-  test("test saving rounds") {
+  test("test checkpoint rounds") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
-    assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
-    manager.updateModel(model4)
-    assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
+    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
+    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+    manager.updateCheckpoint(model4)
+    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
   }
 }
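Taken together, these tests pin down the checkpoint lifecycle: each save replaces the previous file, checkpoints are named `<version>.model`, and a stored version of 2n corresponds to n completed rounds (the manager computes prevRounds as version / 2). A condensed trace, reusing the suite's `manager`, `model4`, and `model8` fixtures:

    manager.updateCheckpoint(model4) // the directory now holds exactly one file: 4.model
    manager.updateCheckpoint(model8) // 4.model is deleted; 8.model is written
    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8) // resume point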
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 053dff156..64f8b8ca2 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -338,7 +338,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     }
   }
 
-  test("training with saving checkpoint boosters") {
+  test("training with checkpoint boosters") {
     import DataUtils._
     val eval = new EvalError()
     val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
@@ -347,7 +347,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
     val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
       "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "saving_frequency" -> 2).toMap
+      "checkpoint_interval" -> 2).toMap
     val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
     def error(model: XGBoostModel): Float = eval.eval(