[jvm-packages] cleaning checkpoint file after a successful training (#4754)
* cleaning checkpoint file after a successful training
* address comments
parent ef9af33a00
commit 7b5cbcc846
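For orientation, a minimal usage sketch of the behavior this commit introduces, based on the parameter keys exercised by the tests in the diff below; the checkpoint directory path and the surrounding training DataFrame are hypothetical and not part of the commit:

import org.apache.spark.sql.DataFrame
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

// Sketch only: `trainingDF` is assumed to be a DataFrame with the usual
// "features"/"label" columns; the parameter keys come from this diff's tests.
def trainWithCheckpoints(trainingDF: DataFrame): XGBoostClassificationModel = {
  val params = Map(
    "eta" -> "1",
    "max_depth" -> 2,
    "objective" -> "binary:logistic",
    "num_round" -> 8,
    "num_workers" -> 2,
    // write a checkpoint every 2 rounds under this (hypothetical) path
    "checkpoint_path" -> "/tmp/xgb-checkpoints",
    "checkpoint_interval" -> 2,
    // new in this commit: with the default (false) the checkpoint directory is
    // deleted once training succeeds; the tests set it to true so the
    // intermediate boosters remain on disk and can be inspected
    "skip_clean_checkpoint" -> false
  )
  new XGBoostClassifier(params).fit(trainingDF)
}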
@@ -53,6 +53,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
     }
   }
 
+  def cleanPath(): Unit = {
+    if (checkpointPath != "") {
+      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
+    }
+  }
+
   /**
    * Load existing checkpoint with the highest version as a Booster object
    *
@@ -127,7 +133,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
 
 object CheckpointManager {
 
-  private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
+  case class CheckpointParam(
+      checkpointPath: String,
+      checkpointInterval: Int,
+      skipCleanCheckpoint: Boolean)
+
+  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
     val checkpointPath: String = params.get("checkpoint_path") match {
       case None => ""
       case Some(path: String) => path
@@ -141,6 +152,13 @@ object CheckpointManager {
       case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
         " an instance of Int.")
     }
-    (checkpointPath, checkpointInterval)
+
+    val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean")
+    }
+    CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
   }
 }

@@ -331,9 +331,11 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
         " an instance of Long.")
     }
-    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
+    val checkpointParam =
+      CheckpointManager.extractParams(params)
     (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval)
+      checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
+      checkpointParam.skipCleanCheckpoint)
   }
 
   private def trainForNonRanking(
@@ -343,7 +345,7 @@ object XGBoost extends Serializable {
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
@@ -373,7 +375,7 @@ object XGBoost extends Serializable {
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
@@ -427,7 +429,8 @@ object XGBoost extends Serializable {
       (Booster, Map[String, Array[Float]]) = {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
+      checkpointPath, checkpointInterval, skipCleanCheckpoint) =
+      parameterFetchAndValidation(params,
       trainingData.sparkContext)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -437,7 +440,7 @@ object XGBoost extends Serializable {
     var prevBooster = checkpointManager.loadCheckpointAsBooster
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
         checkpointRound: Int =>
           val tracker = startTracker(nWorkers, trackerConf)
           try {
@@ -473,6 +476,11 @@ object XGBoost extends Serializable {
             tracker.stop()
           }
       }.last
+      // we should delete the checkpoint directory after a successful training
+      if (!skipCleanCheckpoint) {
+        checkpointManager.cleanPath()
+      }
+      producedBooster
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception

@@ -80,6 +80,13 @@ private[spark] trait LearningTaskParams extends Params {
   final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
     "whether caching training data")
 
+  /**
+   * whether cleaning checkpoint, always cleaning by default, having this parameter majorly for
+   * testing
+   */
+  final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
+    "whether cleaning checkpoint data")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.

@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
 
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import org.scalatest.FunSuite
 import org.apache.hadoop.fs.{FileSystem, Path}
 
@@ -67,4 +68,50 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest
     assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
   }
 
+
+  private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
+    val skipCleanCheckpointMap =
+      if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
+      skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    if (skipCleanCheckpoint) {
+      // Check only one model is kept after training
+      val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+      assert(files.length == 1)
+      assert(files.head.getPath.getName == "8.model")
+      val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+      // Train next model based on prev model
+      val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(prevModel._booster) > error(nextModel._booster))
+      assert(error(nextModel._booster) < 0.1)
+    } else {
+      assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
+    }
+  }
+
+  test("training with checkpoint boosters") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
+  }
+
+  test("training with checkpoint boosters with cached training dataset") {
+    trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
+  }
+
+  test("the checkpoint file should be cleaned after a successful training") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
+  }
 }

@@ -179,60 +179,6 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     assert(x < 0.1)
   }
 
-  test("training with checkpoint boosters") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
-  test("training with checkpoint boosters with cached training dataset") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {