[jvm-packages] Update docs and unify the terminology (#3024)
* [jvm-packages] Move cache files to tmp dir and delete on exit
* [jvm-packages] Update docs and unify terminology
* Address CR Comments
This commit is contained in:
parent 84ab74f3a5
commit 3f3f54bcad
@@ -22,12 +22,16 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext

/**
 * A class which allows user to save checkpoint boosters every a few rounds. If a previous job
 * fails, the job can restart training from a saved booster instead of from scratch. This class
 * A class which allows user to save checkpoints every a few rounds. If a previous job fails,
 * the job can restart training from a saved checkpoints instead of from scratch. This class
 * provides interface and helper methods for the checkpoint functionality.
 *
 * NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level
 * checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS
 * for every a few iterations.
 *
 * @param sc the sparkContext object
 * @param checkpointPath the hdfs path to store checkpoint boosters
 * @param checkpointPath the hdfs path to store checkpoints
 */
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {

  private val logger = LogFactory.getLog("XGBoostSpark")
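To set the renames in this file in context, here is a minimal sketch of how the manager is driven from the training loop in XGBoost.scala (see the later hunks in this commit). It only uses method and parameter names introduced in this diff; checkpointManager, checkpointInterval and round are assumed to be in scope, and trainUpTo is a hypothetical stand-in for the buildDistributedBoosters call:

// Resume from the newest checkpoint (null if none exists), then train in
// checkpoint-sized chunks, persisting an intermediate model after each chunk.
var prevBooster = checkpointManager.loadCheckpointAsBooster
checkpointManager.getCheckpointRounds(checkpointInterval, round).foreach { checkpointRound =>
  val model = trainUpTo(checkpointRound, prevBooster)  // hypothetical single training step
  if (checkpointRound < round) {
    prevBooster = model.booster
    checkpointManager.updateCheckpoint(model)
  }
}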
@@ -49,11 +53,11 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
  }

  /**
   * Load existing checkpoint with the highest version.
   * Load existing checkpoint with the highest version as a Booster object
   *
   * @return the booster with the highest version, null if no checkpoints available.
   */
  private[spark] def loadBooster: Booster = {
  private[spark] def loadCheckpointAsBooster: Booster = {
    val versions = getExistingVersions
    if (versions.nonEmpty) {
      val version = versions.max
@@ -68,16 +72,16 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
  }

  /**
   * Clean up all previous models and save a new model
   * Clean up all previous checkpoints and save a new checkpoint
   *
   * @param model the xgboost model to save
   * @param checkpoint the checkpoint to save as an XGBoostModel
   */
  private[spark] def updateModel(model: XGBoostModel): Unit = {
  private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
    val fs = FileSystem.get(sc.hadoopConfiguration)
    val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
    val fullPath = getPath(model.version)
    logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
    model.saveModelAsHadoopFile(fullPath)(sc)
    val fullPath = getPath(checkpoint.version)
    logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
    checkpoint.saveModelAsHadoopFile(fullPath)(sc)
    prevModelPaths.foreach(path => fs.delete(path, true))
  }
@@ -95,22 +99,22 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
  }

  /**
   * Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
   * total number of rounds for the training. Concretely, the saving rounds start with
   * prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
   * number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
   * returns Seq(round)
   * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
   * and total number of rounds for the training. Concretely, the checkpoint rounds start with
   * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
   * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
   * and the method returns Seq(round)
   *
   * @param savingFreq the increase on rounds during each step of training
   * @param checkpointInterval Period (in iterations) between checkpoints.
   * @param round the total number of rounds for the training
   * @return a seq of integers, each represent the index of round to save the checkpoints
   */
  private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
    if (checkpointPath.nonEmpty && savingFreq > 0) {
  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
      val prevRounds = getExistingVersions.map(_ / 2)
      val firstSavingRound = (0 +: prevRounds).max + savingFreq
      (firstSavingRound until round by savingFreq) :+ round
    } else if (savingFreq <= 0) {
      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
      (firstCheckpointRound until round by checkpointInterval) :+ round
    } else if (checkpointInterval <= 0) {
      Seq(round)
    } else {
      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
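Read alongside the doc comment above, the interval arithmetic is easiest to see as a standalone sketch. Here, rounds is a hypothetical helper that mirrors getCheckpointRounds with the checkpoint_path check stripped out and the previously checkpointed rounds passed in directly; the expected values match the updated CheckpointManagerSuite later in this commit:

// A minimal sketch of the checkpoint-round calculation above; `rounds` is a
// hypothetical standalone helper, not part of the patched class.
def rounds(checkpointInterval: Int, round: Int, prevRounds: Seq[Int] = Seq.empty): Seq[Int] =
  if (checkpointInterval > 0) {
    val first = (0 +: prevRounds).max + checkpointInterval
    (first until round by checkpointInterval) :+ round
  } else {
    Seq(round)
  }

assert(rounds(0, 7) == Seq(7))                            // interval 0: only the final round
assert(rounds(2, 7) == Seq(2, 4, 6, 7))                   // no prior checkpoints
assert(rounds(2, 7, prevRounds = Seq(2)) == Seq(4, 6, 7)) // resuming after a checkpoint at round 2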
@@ -128,12 +132,12 @@ object CheckpointManager {
        " an instance of String.")
    }

    val savingFreq: Int = params.get("saving_frequency") match {
    val checkpointInterval: Int = params.get("checkpoint_interval") match {
      case None => 0
      case Some(freq: Int) => freq
      case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
    }
    (checkpointPath, savingFreq)
    (checkpointPath, checkpointInterval)
  }
}
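A sketch of how the renamed parameters are meant to be supplied, mirroring the updated XGBoostGeneralSuite test at the end of this diff; the HDFS path here is a hypothetical placeholder, and trainingRDD and numWorkers are the names used in that test:

// Pass checkpoint_path and checkpoint_interval through the ordinary params map;
// extractParams above pulls them back out on the driver.
val paramMap = Map(
  "eta" -> "1", "max_depth" -> 2, "silent" -> "1",
  "objective" -> "binary:logistic",
  "checkpoint_path" -> "hdfs:///tmp/xgboost-checkpoints", // hypothetical folder for checkpoints
  "checkpoint_interval" -> 2)                             // save a checkpoint every 2 rounds
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)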
@@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.nio.file.Files

import scala.collection.mutable
import scala.util.Random
@@ -120,11 +121,9 @@ object XGBoost extends Serializable {
    }
    val taskId = TaskContext.getPartitionId().toString
    val cacheDirName = if (useExternalMemory) {
      val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
      if (!(dir.exists() || dir.mkdirs())) {
        throw new XGBoostError(s"failed to create cache directory: $dir")
      }
      Some(dir.toString)
      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
      new File(dir.toUri).deleteOnExit()
      Some(dir.toAbsolutePath.toString)
    } else {
      None
    }
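The pattern introduced above, shown as a self-contained sketch: the external-memory cache directory now lives under the system temp dir and is registered for deletion when the JVM exits. The "stage-0-cache-0" prefix is a hypothetical stand-in for s"${stageId}-cache-$taskId":

import java.io.File
import java.nio.file.Files

// Create the cache directory under java.io.tmpdir and arrange for cleanup on exit.
val cacheDir = Files.createTempDirectory("stage-0-cache-0")
new File(cacheDir.toUri).deleteOnExit()
val cacheDirName = Some(cacheDir.toAbsolutePath.toString)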
@@ -325,23 +324,24 @@ object XGBoost extends Serializable {
      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
        " an instance of Long.")
    }
    val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
    val partitionedData = repartitionForTraining(trainingData, nWorkers)

    val sc = trainingData.sparkContext
    val checkpointManager = new CheckpointManager(sc, checkpointPath)
    checkpointManager.cleanUpHigherVersions(round)

    var prevBooster = checkpointManager.loadBooster
    var prevBooster = checkpointManager.loadCheckpointAsBooster
    // Train for every ${savingRound} rounds and save the partially completed booster
    checkpointManager.getSavingRounds(savingFeq, round).map {
      savingRound: Int =>
    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
      checkpointRound: Int =>
        val tracker = startTracker(nWorkers, trackerConf)
        try {
          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
          val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
            tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
            tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
            prevBooster)
          val sparkJobThread = new Thread() {
            override def run() {
              // force the job
@@ -359,9 +359,9 @@ object XGBoost extends Serializable {
            model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
              params.getOrElse("num_class", "2").toString.toInt
          }
          if (savingRound < round) {
          if (checkpointRound < round) {
            prevBooster = model.booster
            checkpointManager.updateModel(model)
            checkpointManager.updateCheckpoint(model)
          }
          model
        } finally {
@@ -71,7 +71,7 @@ trait GeneralParams extends Params {
  val missing = new FloatParam(this, "missing", "the value treated as missing")

  /**
   * the interval to check whether total numCores is no smaller than nWorkers. default: 30 minutes
   * the maximum time to wait for the job requesting new workers. default: 30 minutes
   */
  val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
    " request new Workers if numCores are insufficient. The timeout will be disabled if this" +
@@ -81,16 +81,19 @@ trait GeneralParams extends Params {
   * The hdfs folder to load and save checkpoint boosters. default: `empty_string`
   */
  val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
    "save checkpoints. The job will try to load the existing booster as the starting point for " +
    "training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
    "save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " +
    "the checkpoint with highest version as the starting point for training. If " +
    "checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")

  /**
   * The frequency to save checkpoint boosters. default: 0
   * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that
   * the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
   * also be set if the checkpoint interval is greater than 0.
   */
  val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
    " the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
    " same setting, it will load the existing booster instead of training from scratch." +
    " Checkpoint will be disabled if set to 0.")
  val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
    "interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
    "checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
    " interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)

  /**
   * Rabit tracker configurations. The parameter must be provided as an instance of the
@@ -128,6 +131,6 @@ trait GeneralParams extends Params {
    useExternalMemory -> false, silent -> 0,
    customObj -> null, customEval -> null, missing -> Float.NaN,
    trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
    checkpointPath -> "", savingFrequency -> 0
    checkpointPath -> "", checkpointInterval -> -1
  )
}
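The new IntParam attaches a validity check and the default moves from 0 to -1. A small sketch of the accepted values, assuming only the predicate shown in the diff above:

// Same predicate as the one passed to the checkpointInterval IntParam:
// -1 disables checkpointing, any value >= 1 checkpoints every that many rounds; 0 is rejected.
val isValidCheckpointInterval: Int => Boolean = interval => interval == -1 || interval >= 1
assert(isValidCheckpointInterval(-1))  // default: checkpointing disabled
assert(isValidCheckpointInterval(10))  // checkpoint every 10 rounds
assert(!isValidCheckpointInterval(0))  // 0 is no longer a valid "disabled" value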
@@ -45,23 +45,23 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
  test("test update/load models") {
    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
    val manager = new CheckpointManager(sc, tmpPath)
    manager.updateModel(model4)
    manager.updateCheckpoint(model4)
    var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
    assert(files.length == 1)
    assert(files.head.getPath.getName == "4.model")
    assert(manager.loadBooster.booster.getVersion == 4)
    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)

    manager.updateModel(model8)
    manager.updateCheckpoint(model8)
    files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
    assert(files.length == 1)
    assert(files.head.getPath.getName == "8.model")
    assert(manager.loadBooster.booster.getVersion == 8)
    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
  }

  test("test cleanUpHigherVersions") {
    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
    val manager = new CheckpointManager(sc, tmpPath)
    manager.updateModel(model8)
    manager.updateCheckpoint(model8)
    manager.cleanUpHigherVersions(round = 8)
    assert(new File(s"$tmpPath/8.model").exists())
@@ -69,12 +69,12 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
    assert(!new File(s"$tmpPath/8.model").exists())
  }

  test("test saving rounds") {
  test("test checkpoint rounds") {
    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
    val manager = new CheckpointManager(sc, tmpPath)
    assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
    assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
    manager.updateModel(model4)
    assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
    manager.updateCheckpoint(model4)
    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
  }
}
@@ -338,7 +338,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    }
  }

  test("training with saving checkpoint boosters") {
  test("training with checkpoint boosters") {
    import DataUtils._
    val eval = new EvalError()
    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
@@ -347,7 +347,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
    val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
      "saving_frequency" -> 2).toMap
      "checkpoint_interval" -> 2).toMap
    val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
      nWorkers = numWorkers)
    def error(model: XGBoostModel): Float = eval.eval(