[jvm-packages] Update docs and unify the terminology (#3024)

* [jvm-packages] Move cache files to tmp dir and delete on exit

* [jvm-packages] Update docs and unify terminology

* Address CR Comments
Yun Ni 2018-01-16 08:16:55 -08:00 committed by Sergei Lebedev
parent 84ab74f3a5
commit 3f3f54bcad
5 changed files with 65 additions and 58 deletions

View File

@@ -22,12 +22,16 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class which allows user to save checkpoint boosters every a few rounds. If a previous job
* fails, the job can restart training from a saved booster instead of from scratch. This class
* A class which allows the user to save checkpoints every few rounds. If a previous job fails,
* the job can restart training from a saved checkpoint instead of from scratch. This class
* provides interface and helper methods for the checkpoint functionality.
*
* NOTE: This checkpoint is different from a Rabit checkpoint. A Rabit checkpoint is a native-level
* checkpoint stored in executor memory. This is a checkpoint that the Spark driver stores on HDFS
* every few iterations.
*
* @param sc the sparkContext object
* @param checkpointPath the hdfs path to store checkpoint boosters
* @param checkpointPath the hdfs path to store checkpoints
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
private val logger = LogFactory.getLog("XGBoostSpark")
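
For orientation, here is a minimal sketch of how the renamed API is meant to be used from the driver side, assuming the signatures shown in this diff. The helper name resumeFromCheckpoint is illustrative, and the calls only compile inside ml.dmlc.xgboost4j.scala.spark because CheckpointManager is private[spark]:

    import ml.dmlc.xgboost4j.scala.Booster
    import org.apache.spark.SparkContext

    // Sketch: resume training state from the newest driver-side checkpoint on HDFS.
    def resumeFromCheckpoint(sc: SparkContext, checkpointPath: String, round: Int): Booster = {
      val manager = new CheckpointManager(sc, checkpointPath)
      manager.cleanUpHigherVersions(round)   // drop checkpoints written past this run's target round
      manager.loadCheckpointAsBooster        // highest-version checkpoint, or null if none exists
    }
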
@@ -49,11 +53,11 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
}
/**
* Load existing checkpoint with the highest version.
* Load the existing checkpoint with the highest version as a Booster object
*
* @return the booster with the highest version, null if no checkpoints available.
*/
private[spark] def loadBooster: Booster = {
private[spark] def loadCheckpointAsBooster: Booster = {
val versions = getExistingVersions
if (versions.nonEmpty) {
val version = versions.max
@@ -68,16 +72,16 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
}
/**
* Clean up all previous models and save a new model
* Clean up all previous checkpoints and save a new checkpoint
*
* @param model the xgboost model to save
* @param checkpoint the checkpoint to save as an XGBoostModel
*/
private[spark] def updateModel(model: XGBoostModel): Unit = {
private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(model.version)
logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
model.saveModelAsHadoopFile(fullPath)(sc)
val fullPath = getPath(checkpoint.version)
logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
checkpoint.saveModelAsHadoopFile(fullPath)(sc)
prevModelPaths.foreach(path => fs.delete(path, true))
}
@@ -95,22 +99,22 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
}
/**
* Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
* total number of rounds for the training. Concretely, the saving rounds start with
* prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
* number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
* returns Seq(round)
* Calculate the list of rounds at which to save checkpoints, based on checkpointInterval and
* the total number of rounds for the training. Concretely, the checkpoint rounds start at
* prevRounds + checkpointInterval and increase by checkpointInterval at each step until they
* reach the total number of rounds. If checkpointInterval is 0, checkpointing is disabled and
* the method returns Seq(round).
*
* @param savingFreq the increase on rounds during each step of training
* @param checkpointInterval Period (in iterations) between checkpoints.
* @param round the total number of rounds for the training
* @return a seq of integers, each representing the round at which to save a checkpoint
*/
private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && savingFreq > 0) {
private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && checkpointInterval > 0) {
val prevRounds = getExistingVersions.map(_ / 2)
val firstSavingRound = (0 +: prevRounds).max + savingFreq
(firstSavingRound until round by savingFreq) :+ round
} else if (savingFreq <= 0) {
val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
(firstCheckpointRound until round by checkpointInterval) :+ round
} else if (checkpointInterval <= 0) {
Seq(round)
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
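
As a worked example of the scaladoc above, matching the assertions in CheckpointManagerSuite later in this commit: an existing checkpoint's version is divided by 2 to recover the number of completed rounds (the prevRounds mapping above), so with checkpointInterval = 2 and round = 7 a saved 4.model (4 / 2 = 2 completed rounds) shifts the next checkpoint rounds to 4, 6 and finally 7.

    // Mirrors CheckpointManagerSuite ("test checkpoint rounds") below; `manager` wraps a
    // non-empty checkpointPath and `model4` is a model whose booster version is 4.
    manager.getCheckpointRounds(checkpointInterval = 0, round = 7)  // Seq(7): checkpointing disabled
    manager.getCheckpointRounds(checkpointInterval = 2, round = 7)  // Seq(2, 4, 6, 7)
    manager.updateCheckpoint(model4)                                // persists 4.model, i.e. 2 rounds done
    manager.getCheckpointRounds(checkpointInterval = 2, round = 7)  // now Seq(4, 6, 7)
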
@@ -128,12 +132,12 @@ object CheckpointManager {
" an instance of String.")
}
val savingFreq: Int = params.get("saving_frequency") match {
val checkpointInterval: Int = params.get("checkpoint_interval") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
" an instance of Int.")
}
(checkpointPath, savingFreq)
(checkpointPath, checkpointInterval)
}
}

View File

@@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import scala.collection.mutable
import scala.util.Random
@@ -120,11 +121,9 @@ object XGBoost extends Serializable {
}
val taskId = TaskContext.getPartitionId().toString
val cacheDirName = if (useExternalMemory) {
val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
if (!(dir.exists() || dir.mkdirs())) {
throw new XGBoostError(s"failed to create cache directory: $dir")
}
Some(dir.toString)
val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
new File(dir.toUri).deleteOnExit()
Some(dir.toAbsolutePath.toString)
} else {
None
}
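
The external-memory cache directory is now allocated under the system temp directory instead of the task's working directory, so the manual exists/mkdirs check goes away and the directory is registered for removal on JVM exit. A self-contained sketch of the same pattern (the prefix string here is illustrative):

    import java.io.File
    import java.nio.file.Files

    // Create a uniquely named directory under java.io.tmpdir, using a
    // "<stageId>-cache-<taskId>"-style prefix as in the code above.
    val dir = Files.createTempDirectory("0-cache-0")

    // Register the directory for deletion on normal JVM shutdown; note that
    // File.deleteOnExit only removes it if it is empty at that point.
    new File(dir.toUri).deleteOnExit()

    val cacheDirName: Option[String] = Some(dir.toAbsolutePath.toString)
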
@@ -325,23 +324,24 @@ object XGBoost extends Serializable {
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
" an instance of Long.")
}
val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
val partitionedData = repartitionForTraining(trainingData, nWorkers)
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, checkpointPath)
checkpointManager.cleanUpHigherVersions(round)
var prevBooster = checkpointManager.loadBooster
var prevBooster = checkpointManager.loadCheckpointAsBooster
// Train up to ${checkpointRound} rounds and save the partially completed booster
checkpointManager.getSavingRounds(savingFeq, round).map {
savingRound: Int =>
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
checkpointRound: Int =>
val tracker = startTracker(nWorkers, trackerConf)
try {
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
prevBooster)
val sparkJobThread = new Thread() {
override def run() {
// force the job
@@ -359,9 +359,9 @@ object XGBoost extends Serializable {
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
if (savingRound < round) {
if (checkpointRound < round) {
prevBooster = model.booster
checkpointManager.updateModel(model)
checkpointManager.updateCheckpoint(model)
}
model
} finally {
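
Condensed, the checkpoint-aware driver loop introduced above has roughly this shape. This is a simplified sketch that reuses params, sc and round from the surrounding method, omits the Rabit tracker, SparkParallelismTracker and error handling, and uses a placeholder for the distributed training step itself:

    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
    val checkpointManager = new CheckpointManager(sc, checkpointPath)
    checkpointManager.cleanUpHigherVersions(round)
    var prevBooster = checkpointManager.loadCheckpointAsBooster

    val models = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
      checkpointRound: Int =>
        // Placeholder for buildDistributedBoosters(...) and the Spark job that drives it.
        val model: XGBoostModel = ???
        if (checkpointRound < round) {
          prevBooster = model.booster                 // the next segment resumes from this booster
          checkpointManager.updateCheckpoint(model)   // persist it under checkpointPath on HDFS
        }
        model
    }
    models.last  // the model trained through the final round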

View File

@@ -71,7 +71,7 @@ trait GeneralParams extends Params {
val missing = new FloatParam(this, "missing", "the value treated as missing")
/**
* the interval to check whether total numCores is no smaller than nWorkers. default: 30 minutes
* the maximum time to wait for the job to request new workers. default: 30 minutes
*/
val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
@@ -81,16 +81,19 @@ trait GeneralParams extends Params {
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
*/
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
"save checkpoints. The job will try to load the existing booster as the starting point for " +
"training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
"save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " +
"the checkpoint with highest version as the starting point for training. If " +
"checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
/**
* The frequency to save checkpoint boosters. default: 0
* Param for setting the checkpoint interval (&gt;= 1) or disabling checkpointing (-1). E.g. 10
* means that the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path`
* must also be set if the checkpoint interval is greater than 0.
*/
val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
" the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
" same setting, it will load the existing booster instead of training from scratch." +
" Checkpoint will be disabled if set to 0.")
val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
"interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
"checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
" interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
@@ -128,6 +131,6 @@ trait GeneralParams extends Params {
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", savingFrequency -> 0
checkpointPath -> "", checkpointInterval -> -1
)
}
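
On the RDD API side these parameters are plain entries in the params map, as exercised by the XGBoostGeneralSuite test at the end of this commit. A sketch, where the HDFS path is illustrative and trainingRDD / numWorkers stand in for the values used in that test:

    // Resumable training through the params map: checkpoint every 2 rounds under checkpoint_path.
    val paramMap = Map(
      "eta" -> "1", "max_depth" -> 2, "silent" -> "1",
      "objective" -> "binary:logistic",
      "checkpoint_path" -> "hdfs:///tmp/xgboost-checkpoints",  // illustrative location
      "checkpoint_interval" -> 2)
    val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)

Note that the ML param checkpointInterval defaults to -1 and only accepts -1 or values >= 1, while the low-level "checkpoint_interval" entry defaults to 0 when absent (see extractParams above); both defaults mean no intermediate checkpoints are written.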

View File

@@ -45,23 +45,23 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model4)
manager.updateCheckpoint(model4)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadBooster.booster.getVersion == 4)
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
manager.updateModel(model8)
manager.updateCheckpoint(model8)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadBooster.booster.getVersion == 8)
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model8)
manager.updateCheckpoint(model8)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
@@ -69,12 +69,12 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test saving rounds") {
test("test checkpoint rounds") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
manager.updateModel(model4)
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
manager.updateCheckpoint(model4)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
}
}

View File

@@ -338,7 +338,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}
}
test("training with saving checkpoint boosters") {
test("training with checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
@@ -347,7 +347,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"saving_frequency" -> 2).toMap
"checkpoint_interval" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(