[jvm-packages] Saving models into a tmp folder every a few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files which is missing in previous commit

* Rename TmpBooster to checkpoint
This commit is contained in:
Yun Ni
2017-12-29 08:36:41 -08:00
committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -0,0 +1,139 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class which allows user to save checkpoint boosters every a few rounds. If a previous job
* fails, the job can restart training from a saved booster instead of from scratch. This class
* provides interface and helper methods for the checkpoint functionality.
*
* @param sc the sparkContext object
* @param checkpointPath the hdfs path to store checkpoint boosters
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
private val logger = LogFactory.getLog("XGBoostSpark")
private val modelSuffix = ".model"
private def getPath(version: Int) = {
s"$checkpointPath/$version$modelSuffix"
}
private def getExistingVersions: Seq[Int] = {
val fs = FileSystem.get(sc.hadoopConfiguration)
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
Seq()
} else {
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
}
}
}
/**
* Load existing checkpoint with the highest version.
*
* @return the booster with the highest version, null if no checkpoints available.
*/
private[spark] def loadBooster: Booster = {
val versions = getExistingVersions
if (versions.nonEmpty) {
val version = versions.max
val fullPath = getPath(version)
logger.info(s"Start training from previous booster at $fullPath")
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
model.booster.booster.setVersion(version)
model.booster
} else {
null
}
}
/**
* Clean up all previous models and save a new model
*
* @param model the xgboost model to save
*/
private[spark] def updateModel(model: XGBoostModel): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(model.version)
logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
model.saveModelAsHadoopFile(fullPath)(sc)
prevModelPaths.foreach(path => fs.delete(path, true))
}
/**
* Clean up checkpoint boosters with version higher than or equal to the round.
*
* @param round the number of rounds in the current training job
*/
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
higherVersions.foreach { version =>
val fs = FileSystem.get(sc.hadoopConfiguration)
fs.delete(new Path(getPath(version)), true)
}
}
/**
* Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
* total number of rounds for the training. Concretely, the saving rounds start with
* prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
* number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
* returns Seq(round)
*
* @param savingFreq the increase on rounds during each step of training
* @param round the total number of rounds for the training
* @return a seq of integers, each represent the index of round to save the checkpoints
*/
private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && savingFreq > 0) {
val prevRounds = getExistingVersions.map(_ / 2)
val firstSavingRound = (0 +: prevRounds).max + savingFreq
(firstSavingRound until round by savingFreq) :+ round
} else if (savingFreq <= 0) {
Seq(round)
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
}
}
}
object CheckpointManager {
private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None => ""
case Some(path: String) => path
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
" an instance of String.")
}
val savingFreq: Int = params.get("saving_frequency") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
" an instance of Int.")
}
(checkpointPath, savingFreq)
}
}

View File

@@ -20,7 +20,6 @@ import java.io.File
import scala.collection.mutable
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -101,23 +100,19 @@ object XGBoost extends Serializable {
data: RDD[XGBLabeledPoint],
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
numWorkers: Int,
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
useExternalMemory: Boolean,
missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
val partitionedData = if (data.getNumPartitions != numWorkers) {
logger.info(s"repartitioning training set to $numWorkers partitions")
data.repartition(numWorkers)
} else {
data
}
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
missing: Float,
prevBooster: Booster
): RDD[(Booster, Map[String, Array[Float]])] = {
val partitionedBaseMargin = data.map(_.baseMargin)
// to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
if (labeledPoints.isEmpty) {
throw new XGBoostError(
s"detected an empty partition in the training data, partition ID:" +
@@ -145,7 +140,7 @@ object XGBoost extends Serializable {
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds)
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} finally {
Rabit.shutdown()
@@ -330,34 +325,58 @@ object XGBoost extends Serializable {
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
" an instance of Long.")
}
val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
val partitionedData = repartitionForTraining(trainingData, nWorkers)
val tracker = startTracker(nWorkers, trackerConf)
try {
val sc = trainingData.sparkContext
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, checkpointPath)
checkpointManager.cleanUpHigherVersions(round)
var prevBooster = checkpointManager.loadBooster
// Train for every ${savingRound} rounds and save the partially completed booster
checkpointManager.getSavingRounds(savingFeq, round).map {
savingRound: Int =>
val tracker = startTracker(nWorkers, trackerConf)
try {
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
if (savingRound < round) {
prevBooster = model.booster
checkpointManager.updateModel(model)
}
model
} finally {
tracker.stop()
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
model
} finally {
tracker.stop()
}.last
}
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
if (trainingData.getNumPartitions != nWorkers) {
logger.info(s"repartitioning training set to $nWorkers partitions")
trainingData.repartition(nWorkers)
} else {
trainingData
}
}
@@ -405,6 +424,7 @@ object XGBoost extends Serializable {
xgBoostModel.setPredictionCol(predCol)
}
/**
* Load XGBoost model from path in HDFS-compatible file system
*

View File

@@ -344,6 +344,8 @@ abstract class XGBoostModel(protected var _booster: Booster)
def booster: Booster = _booster
def version: Int = this.booster.booster.getVersion
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)

View File

@@ -77,6 +77,21 @@ trait GeneralParams extends Params {
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
" value is set smaller than or equal to 0.")
/**
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
*/
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
"save checkpoints. The job will try to load the existing booster as the starting point for " +
"training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
/**
* The frequency to save checkpoint boosters. default: 0
*/
val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
" the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
" same setting, it will load the existing booster instead of training from scratch." +
" Checkpoint will be disabled if set to 0.")
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
@@ -112,6 +127,7 @@ trait GeneralParams extends Params {
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", savingFrequency -> 0
)
}

View File

@@ -0,0 +1,80 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkConf, SparkContext}
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
}
private lazy val (model4, model8) = {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, sc.defaultParallelism),
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, sc.defaultParallelism))
}
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model4)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadBooster.booster.getVersion == 4)
manager.updateModel(model8)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadBooster.booster.getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model8)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test saving rounds") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
manager.updateModel(model4)
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
}
}

View File

@@ -16,13 +16,14 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -73,13 +74,14 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
val boosterRDD = XGBoost.buildDistributedBoosters(
trainingRDD,
partitionedRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN)
round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN, prevBooster = null)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
}
@@ -335,4 +337,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
}
test("training with saving checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"saving_frequency" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
}
}