[jvm-packages] Update docs and unify the terminology (#3024)

* [jvm-packages] Move cache files to tmp dir and delete on exit

* [jvm-packages] Update docs and unify terminology

* Address CR Comments
Authored by Yun Ni on 2018-01-16 08:16:55 -08:00, committed by Sergei Lebedev
parent 84ab74f3a5
commit 3f3f54bcad
5 changed files with 65 additions and 58 deletions

View File

@@ -22,12 +22,16 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.SparkContext
 
 /**
- * A class which allows user to save checkpoint boosters every a few rounds. If a previous job
- * fails, the job can restart training from a saved booster instead of from scratch. This class
+ * A class which allows the user to save checkpoints every few rounds. If a previous job fails,
+ * the job can restart training from a saved checkpoint instead of from scratch. This class
  * provides interface and helper methods for the checkpoint functionality.
  *
+ * NOTE: This checkpoint is different from the Rabit checkpoint. The Rabit checkpoint is a
+ * native-level checkpoint stored in executor memory, while this checkpoint is stored by the
+ * Spark driver on HDFS every few iterations.
+ *
  * @param sc the sparkContext object
- * @param checkpointPath the hdfs path to store checkpoint boosters
+ * @param checkpointPath the hdfs path to store checkpoints
  */
 private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
   private val logger = LogFactory.getLog("XGBoostSpark")
@@ -49,11 +53,11 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
   }
 
   /**
-   * Load existing checkpoint with the highest version.
+   * Load existing checkpoint with the highest version as a Booster object.
    *
    * @return the booster with the highest version, null if no checkpoints available.
    */
-  private[spark] def loadBooster: Booster = {
+  private[spark] def loadCheckpointAsBooster: Booster = {
     val versions = getExistingVersions
     if (versions.nonEmpty) {
       val version = versions.max
@@ -68,16 +72,16 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
   }
 
   /**
-   * Clean up all previous models and save a new model
+   * Clean up all previous checkpoints and save a new checkpoint
    *
-   * @param model the xgboost model to save
+   * @param checkpoint the checkpoint to save as an XGBoostModel
    */
-  private[spark] def updateModel(model: XGBoostModel): Unit = {
+  private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
     val fs = FileSystem.get(sc.hadoopConfiguration)
     val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
-    val fullPath = getPath(model.version)
-    logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
-    model.saveModelAsHadoopFile(fullPath)(sc)
+    val fullPath = getPath(checkpoint.version)
+    logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
+    checkpoint.saveModelAsHadoopFile(fullPath)(sc)
     prevModelPaths.foreach(path => fs.delete(path, true))
   }
@@ -95,22 +99,22 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
   }
 
   /**
-   * Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
-   * total number of rounds for the training. Concretely, the saving rounds start with
-   * prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
-   * number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
-   * returns Seq(round)
+   * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
+   * and total number of rounds for the training. Concretely, the checkpoint rounds start with
+   * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
+   * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
+   * and the method returns Seq(round)
    *
-   * @param savingFreq the increase on rounds during each step of training
+   * @param checkpointInterval Period (in iterations) between checkpoints.
    * @param round the total number of rounds for the training
    * @return a seq of integers, each represent the index of round to save the checkpoints
    */
-  private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
-    if (checkpointPath.nonEmpty && savingFreq > 0) {
+  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
+    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
       val prevRounds = getExistingVersions.map(_ / 2)
-      val firstSavingRound = (0 +: prevRounds).max + savingFreq
-      (firstSavingRound until round by savingFreq) :+ round
-    } else if (savingFreq <= 0) {
+      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
+      (firstCheckpointRound until round by checkpointInterval) :+ round
+    } else if (checkpointInterval <= 0) {
       Seq(round)
     } else {
       throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
@@ -128,12 +132,12 @@ object CheckpointManager {
         " an instance of String.")
     }
-    val savingFreq: Int = params.get("saving_frequency") match {
+    val checkpointInterval: Int = params.get("checkpoint_interval") match {
       case None => 0
       case Some(freq: Int) => freq
-      case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
         " an instance of Int.")
     }
-    (checkpointPath, savingFreq)
+    (checkpointPath, checkpointInterval)
   }
 }
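
For reference, a minimal standalone sketch (not part of the patch; the object and parameter names are made up) of the schedule that getCheckpointRounds computes, as described in its doc comment, with the checkpoint_path validation omitted. The printed values match the assertions in CheckpointManagerSuite further down.

// Standalone illustration only: recomputes the schedule described in the doc comment of
// getCheckpointRounds. `prevRounds` stands in for the rounds recovered from existing
// checkpoint versions (version / 2 in the class above).
object CheckpointRoundsSketch {
  def checkpointRounds(checkpointInterval: Int, round: Int, prevRounds: Seq[Int]): Seq[Int] =
    if (checkpointInterval > 0) {
      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
      (firstCheckpointRound until round by checkpointInterval) :+ round
    } else {
      Seq(round)
    }

  def main(args: Array[String]): Unit = {
    println(checkpointRounds(checkpointInterval = 0, round = 7, prevRounds = Nil))    // List(7)
    println(checkpointRounds(checkpointInterval = 2, round = 7, prevRounds = Nil))    // Vector(2, 4, 6, 7)
    println(checkpointRounds(checkpointInterval = 2, round = 7, prevRounds = Seq(2))) // Vector(4, 6, 7)
  }
}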

View File

@@ -17,6 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
+import java.nio.file.Files
 
 import scala.collection.mutable
 import scala.util.Random
@@ -120,11 +121,9 @@ object XGBoost extends Serializable {
     }
     val taskId = TaskContext.getPartitionId().toString
     val cacheDirName = if (useExternalMemory) {
-      val dir = new File(s"${TaskContext.get().stageId()}-cache-$taskId")
-      if (!(dir.exists() || dir.mkdirs())) {
-        throw new XGBoostError(s"failed to create cache directory: $dir")
-      }
-      Some(dir.toString)
+      val dir = Files.createTempDirectory(s"${TaskContext.get().stageId()}-cache-$taskId")
+      new File(dir.toUri).deleteOnExit()
+      Some(dir.toAbsolutePath.toString)
     } else {
       None
     }
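
The hunk above replaces the manually created working-directory cache folder with a temporary directory that is scheduled for deletion when the JVM exits. A rough standalone sketch of the same pattern (the object and method names here are illustrative, not from the patch):

import java.io.File
import java.nio.file.Files

object CacheDirSketch {
  // Creates a per-task cache directory under java.io.tmpdir and registers it for deletion at
  // JVM exit, mirroring the external-memory cache handling introduced above. Note that
  // File.deleteOnExit() only removes the directory if it is empty at shutdown.
  def createCacheDir(stageId: Int, taskId: String): String = {
    val dir = Files.createTempDirectory(s"$stageId-cache-$taskId")
    new File(dir.toUri).deleteOnExit()
    dir.toAbsolutePath.toString
  }
}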
@@ -325,23 +324,24 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
         " an instance of Long.")
     }
-    val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
+    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
 
     val partitionedData = repartitionForTraining(trainingData, nWorkers)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
     checkpointManager.cleanUpHigherVersions(round)
-    var prevBooster = checkpointManager.loadBooster
+    var prevBooster = checkpointManager.loadCheckpointAsBooster
 
     // Train for every ${savingRound} rounds and save the partially completed booster
-    checkpointManager.getSavingRounds(savingFeq, round).map {
-      savingRound: Int =>
+    checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      checkpointRound: Int =>
         val tracker = startTracker(nWorkers, trackerConf)
         try {
           val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
           val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
           val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
-            tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
+            tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
+            prevBooster)
           val sparkJobThread = new Thread() {
             override def run() {
               // force the job
@@ -359,9 +359,9 @@ object XGBoost extends Serializable {
             model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
               params.getOrElse("num_class", "2").toString.toInt
           }
-          if (savingRound < round) {
+          if (checkpointRound < round) {
             prevBooster = model.booster
-            checkpointManager.updateModel(model)
+            checkpointManager.updateCheckpoint(model)
           }
           model
         } finally {

View File

@@ -71,7 +71,7 @@ trait GeneralParams extends Params {
   val missing = new FloatParam(this, "missing", "the value treated as missing")
 
   /**
-   * the interval to check whether total numCores is no smaller than nWorkers. default: 30 minutes
+   * the maximum time to wait for the job requesting new workers. default: 30 minutes
    */
   val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
     " request new Workers if numCores are insufficient. The timeout will be disabled if this" +
@@ -81,16 +81,19 @@ trait GeneralParams extends Params {
    * The hdfs folder to load and save checkpoint boosters. default: `empty_string`
    */
   val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
-    "save checkpoints. The job will try to load the existing booster as the starting point for " +
-    "training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
+    "save checkpoints. If there are existing checkpoints in checkpoint_path, the job will load " +
+    "the checkpoint with the highest version as the starting point for training. If " +
+    "checkpoint_interval is also set, the job will save a checkpoint every few rounds.")
 
   /**
-   * The frequency to save checkpoint boosters. default: 0
+   * Param to set the checkpoint interval (>= 1) or disable checkpointing (-1). E.g. 10 means
+   * that the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path`
+   * must also be set if the checkpoint interval is greater than 0.
    */
-  val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
-    " the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
-    " same setting, it will load the existing booster instead of training from scratch." +
-    " Checkpoint will be disabled if set to 0.")
+  val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
+    "interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
+    "checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
+    " interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
 
   /**
    * Rabit tracker configurations. The parameter must be provided as an instance of the
@@ -128,6 +131,6 @@ trait GeneralParams extends Params {
     useExternalMemory -> false, silent -> 0,
     customObj -> null, customEval -> null, missing -> Float.NaN,
     trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
-    checkpointPath -> "", savingFrequency -> 0
+    checkpointPath -> "", checkpointInterval -> -1
   )
 }
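
With the defaults above (checkpointPath -> "", checkpointInterval -> -1), checkpointing stays off unless both parameters are supplied. A hedged usage sketch for the map-based training API, modeled on the XGBoostGeneralSuite test further down; the HDFS path, round count, and worker count are placeholders:

// Placeholder values; assumes a prepared trainingRDD of labeled points and a worker count,
// as in the test suite below.
val paramMap = Map(
  "eta" -> "1",
  "max_depth" -> 2,
  "objective" -> "binary:logistic",
  // Driver-side checkpointing: save to this HDFS folder every 2 rounds; a restarted job with
  // the same settings resumes from the checkpoint with the highest version.
  "checkpoint_path" -> "hdfs:///tmp/xgboost-checkpoints",
  "checkpoint_interval" -> 2
)
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10, nWorkers = 4)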

View File

@@ -45,23 +45,23 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
   test("test update/load models") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateModel(model4)
+    manager.updateCheckpoint(model4)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "4.model")
-    assert(manager.loadBooster.booster.getVersion == 4)
-    manager.updateModel(model8)
+    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
+    manager.updateCheckpoint(model8)
     files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "8.model")
-    assert(manager.loadBooster.booster.getVersion == 8)
+    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
   }
 
   test("test cleanUpHigherVersions") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateModel(model8)
+    manager.updateCheckpoint(model8)
     manager.cleanUpHigherVersions(round = 8)
     assert(new File(s"$tmpPath/8.model").exists())
@@ -69,12 +69,12 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
     assert(!new File(s"$tmpPath/8.model").exists())
   }
 
-  test("test saving rounds") {
+  test("test checkpoint rounds") {
     val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
     val manager = new CheckpointManager(sc, tmpPath)
-    assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
-    assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
-    manager.updateModel(model4)
-    assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
+    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
+    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+    manager.updateCheckpoint(model4)
+    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
   }
 }

View File

@@ -338,7 +338,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     }
   }
 
-  test("training with saving checkpoint boosters") {
+  test("training with checkpoint boosters") {
     import DataUtils._
     val eval = new EvalError()
     val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
@@ -347,7 +347,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
     val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
       "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "saving_frequency" -> 2).toMap
+      "checkpoint_interval" -> 2).toMap
     val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
       nWorkers = numWorkers)
     def error(model: XGBoostModel): Float = eval.eval(