[jvm-packages] cleaning checkpoint file after a successful training (#4754)
* cleaning checkpoint file after a successful training * address comments
This commit is contained in:
@@ -53,6 +53,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Recursively delete the checkpoint directory.
 *
 * Does nothing when `checkpointPath` is the empty string.
 */
def cleanPath(): Unit = {
  if (checkpointPath != "") {
    val path = new Path(checkpointPath)
    // Resolve the file system from the path itself rather than from the default
    // configuration, so that a checkpoint path with an explicit scheme that differs
    // from the default file system is deleted on the file system it lives on.
    path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
  }
}
|
||||
|
||||
/**
|
||||
* Load existing checkpoint with the highest version as a Booster object
|
||||
*
|
||||
@@ -127,7 +133,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
|
||||
object CheckpointManager {
|
||||
|
||||
private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
|
||||
/**
 * Checkpoint-related settings extracted from the training parameter map.
 *
 * @param checkpointPath value of "checkpoint_path"; the empty string when the key is absent
 * @param checkpointInterval value of "checkpoint_interval"
 * @param skipCleanCheckpoint value of "skip_clean_checkpoint"; false when the key is absent
 */
final case class CheckpointParam(
    checkpointPath: String,
    checkpointInterval: Int,
    skipCleanCheckpoint: Boolean)
|
||||
|
||||
private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
|
||||
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||
case None => ""
|
||||
case Some(path: String) => path
|
||||
@@ -141,6 +152,13 @@ object CheckpointManager {
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
||||
" an instance of Int.")
|
||||
}
|
||||
(checkpointPath, checkpointInterval)
|
||||
|
||||
val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
||||
case None => false
|
||||
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
||||
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
||||
" an instance of Boolean")
|
||||
}
|
||||
CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -331,9 +331,11 @@ object XGBoost extends Serializable {
|
||||
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
|
||||
" an instance of Long.")
|
||||
}
|
||||
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
|
||||
val checkpointParam =
|
||||
CheckpointManager.extractParams(params)
|
||||
(nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval)
|
||||
checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
|
||||
checkpointParam.skipCleanCheckpoint)
|
||||
}
|
||||
|
||||
private def trainForNonRanking(
|
||||
@@ -343,7 +345,7 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
trainingData.mapPartitions(labeledPoints => {
|
||||
@@ -373,7 +375,7 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
trainingData.mapPartitions(labeledPointGroups => {
|
||||
@@ -427,7 +429,8 @@ object XGBoost extends Serializable {
|
||||
(Booster, Map[String, Array[Float]]) = {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
|
||||
checkpointPath, checkpointInterval, skipCleanCheckpoint) =
|
||||
parameterFetchAndValidation(params,
|
||||
trainingData.sparkContext)
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
||||
@@ -437,7 +440,7 @@ object XGBoost extends Serializable {
|
||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
@@ -473,6 +476,11 @@ object XGBoost extends Serializable {
|
||||
tracker.stop()
|
||||
}
|
||||
}.last
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
if (!skipCleanCheckpoint) {
|
||||
checkpointManager.cleanPath()
|
||||
}
|
||||
producedBooster
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
// if the job was aborted due to an exception
|
||||
|
||||
@@ -80,6 +80,13 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
|
||||
"whether caching training data")
|
||||
|
||||
/**
 * Whether to skip cleaning the checkpoint data after a successful training. When this is
 * false the checkpoint directory is deleted once training succeeds. This parameter exists
 * mainly for testing.
 */
final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
  "whether to skip cleaning checkpoint data after a successful training")
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
|
||||
Reference in New Issue
Block a user