[jvm-packages] cleaning checkpoint file after a successful training (#4754)
* cleaning checkpoint file after a successful file * address comments
This commit is contained in:
parent
ef9af33a00
commit
7b5cbcc846
@ -53,6 +53,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
}
|
||||
}
|
||||
|
||||
def cleanPath(): Unit = {
|
||||
if (checkpointPath != "") {
|
||||
FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load existing checkpoint with the highest version as a Booster object
|
||||
*
|
||||
@ -127,7 +133,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
|
||||
object CheckpointManager {
|
||||
|
||||
private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
|
||||
case class CheckpointParam(
|
||||
checkpointPath: String,
|
||||
checkpointInterval: Int,
|
||||
skipCleanCheckpoint: Boolean)
|
||||
|
||||
private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
|
||||
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||
case None => ""
|
||||
case Some(path: String) => path
|
||||
@ -141,6 +152,13 @@ object CheckpointManager {
|
||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
||||
" an instance of Int.")
|
||||
}
|
||||
(checkpointPath, checkpointInterval)
|
||||
|
||||
val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
||||
case None => false
|
||||
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
||||
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
||||
" an instance of Boolean")
|
||||
}
|
||||
CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
|
||||
}
|
||||
}
|
||||
|
||||
@ -331,9 +331,11 @@ object XGBoost extends Serializable {
|
||||
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
|
||||
" an instance of Long.")
|
||||
}
|
||||
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
|
||||
val checkpointParam =
|
||||
CheckpointManager.extractParams(params)
|
||||
(nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval)
|
||||
checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
|
||||
checkpointParam.skipCleanCheckpoint)
|
||||
}
|
||||
|
||||
private def trainForNonRanking(
|
||||
@ -343,7 +345,7 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
trainingData.mapPartitions(labeledPoints => {
|
||||
@ -373,7 +375,7 @@ object XGBoost extends Serializable {
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
|
||||
val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
|
||||
parameterFetchAndValidation(params, trainingData.sparkContext)
|
||||
if (evalSetsMap.isEmpty) {
|
||||
trainingData.mapPartitions(labeledPointGroups => {
|
||||
@ -427,7 +429,8 @@ object XGBoost extends Serializable {
|
||||
(Booster, Map[String, Array[Float]]) = {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
|
||||
checkpointPath, checkpointInterval, skipCleanCheckpoint) =
|
||||
parameterFetchAndValidation(params,
|
||||
trainingData.sparkContext)
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
||||
@ -437,7 +440,7 @@ object XGBoost extends Serializable {
|
||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(nWorkers, trackerConf)
|
||||
try {
|
||||
@ -473,6 +476,11 @@ object XGBoost extends Serializable {
|
||||
tracker.stop()
|
||||
}
|
||||
}.last
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
if (!skipCleanCheckpoint) {
|
||||
checkpointManager.cleanPath()
|
||||
}
|
||||
producedBooster
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
// if the job was aborted due to an exception
|
||||
|
||||
@ -80,6 +80,13 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
|
||||
"whether caching training data")
|
||||
|
||||
/**
|
||||
* whether cleaning checkpoint, always cleaning by default, having this parameter majorly for
|
||||
* testing
|
||||
*/
|
||||
final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
|
||||
"whether cleaning checkpoint data")
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
|
||||
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import org.scalatest.FunSuite
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
@ -67,4 +68,50 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
||||
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
||||
}
|
||||
|
||||
|
||||
private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
||||
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
|
||||
val skipCleanCheckpointMap =
|
||||
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
|
||||
skipCleanCheckpointMap
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
if (skipCleanCheckpoint) {
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
} else {
|
||||
assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
|
||||
}
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters") {
|
||||
trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters with cached training dataset") {
|
||||
trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
|
||||
}
|
||||
|
||||
test("the checkpoint file should be cleaned after a successful training") {
|
||||
trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -179,60 +179,6 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters with cached training dataset") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
}
|
||||
|
||||
test("repartitionForTrainingGroup with group data") {
|
||||
// test different splits to cover the corner cases.
|
||||
for (split <- 1 to 20) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user