[jvm-packages] do not use multiple jobs to make checkpoints (#5082)

* temp

* temp

* temp

* address the comments

* fix stylistic issues

* fix

* external checkpoint
Nan Zhu 2020-02-01 19:36:39 -08:00 committed by GitHub
parent fa26313feb
commit d7b45fbcaf
14 changed files with 464 additions and 320 deletions


@@ -37,6 +37,7 @@
     <spark.version>2.4.3</spark.version>
     <scala.version>2.12.8</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
+    <hadoop.version>2.7.3</hadoop.version>
   </properties>
   <repositories>
     <repository>


@@ -1,164 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class which allows the user to save checkpoints every few rounds. If a previous job
* fails, training can restart from the saved checkpoint instead of from scratch. This class
* provides the interface and helper methods for the checkpoint functionality.
*
* NOTE: This checkpoint is different from a Rabit checkpoint. A Rabit checkpoint is a
* native-level checkpoint stored in executor memory; this one is stored on HDFS by the
* Spark driver every few iterations.
*
* @param sc the SparkContext object
* @param checkpointPath the HDFS path at which to store checkpoints
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
private val logger = LogFactory.getLog("XGBoostSpark")
private val modelSuffix = ".model"
private def getPath(version: Int) = {
s"$checkpointPath/$version$modelSuffix"
}
private def getExistingVersions: Seq[Int] = {
val fs = FileSystem.get(sc.hadoopConfiguration)
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
Seq()
} else {
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
}
}
}
def cleanPath(): Unit = {
if (checkpointPath != "") {
FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
}
}
/**
* Load existing checkpoint with the highest version as a Booster object
*
* @return the booster with the highest version, null if no checkpoints available.
*/
private[spark] def loadCheckpointAsBooster: Booster = {
val versions = getExistingVersions
if (versions.nonEmpty) {
val version = versions.max
val fullPath = getPath(version)
val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
logger.info(s"Start training from previous booster at $fullPath")
val booster = SXGBoost.loadModel(inputStream)
booster.booster.setVersion(version)
booster
} else {
null
}
}
/**
* Clean up all previous checkpoints and save a new checkpoint
*
* @param checkpoint the checkpoint to save as an XGBoostModel
*/
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(checkpoint.getVersion)
val outputStream = fs.create(new Path(fullPath), true)
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
checkpoint.saveModel(outputStream)
prevModelPaths.foreach(path => fs.delete(path, true))
}
/**
* Clean up checkpoint boosters with version higher than or equal to the round.
*
* @param round the number of rounds in the current training job
*/
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
higherVersions.foreach { version =>
val fs = FileSystem.get(sc.hadoopConfiguration)
fs.delete(new Path(getPath(version)), true)
}
}
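// Note: the division by 2 reflects that a booster's internal version counter
// is assumed to advance twice per boosting round (once for the update step,
// once for the eval step), so version / 2 recovers the number of completed
// rounds; e.g. a checkpoint named "8.model" corresponds to round 4.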
/**
* Calculate a list of rounds at which to save checkpoints, based on the checkpointInterval
* and the total number of rounds for the training. Concretely, the checkpoint rounds start
* at prevRounds + checkpointInterval and increase by checkpointInterval at each step until
* reaching the total number of rounds. If checkpointInterval is 0, checkpointing is disabled
* and the method returns Seq(round).
*
* @param checkpointInterval period (in iterations) between checkpoints
* @param round the total number of rounds for the training
* @return a Seq of integers, each representing a round index at which to save a checkpoint
*/
private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && checkpointInterval > 0) {
val prevRounds = getExistingVersions.map(_ / 2)
val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
(firstCheckpointRound until round by checkpointInterval) :+ round
} else if (checkpointInterval <= 0) {
Seq(round)
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
}
}
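// Example: with checkpointInterval = 2 and round = 7 and no prior checkpoints
// this returns Seq(2, 4, 6, 7); if a checkpoint "4.model" (round 2) already
// exists, it returns Seq(4, 6, 7).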
}
object CheckpointManager {
case class CheckpointParam(
checkpointPath: String,
checkpointInterval: Int,
skipCleanCheckpoint: Boolean)
private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None => ""
case Some(path: String) => path
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
" an instance of String.")
}
val checkpointInterval: Int = params.get("checkpoint_interval") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
" an instance of Int.")
}
val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
case None => false
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
" an instance of Boolean")
}
CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
}
}
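For contrast with the replacement introduced later in this commit: the manager above drove checkpointing from the driver by splitting training into one Spark job per checkpoint interval, whereas the new scheme runs a single job and has one worker persist checkpoints while training. A minimal sketch of the old flow (runSparkJob is a hypothetical stand-in for the per-interval training job):

// Old flow: one Spark job per checkpoint round; the driver saves in between
for (checkpointRound <- checkpointManager.getCheckpointRounds(interval, numRounds)) {
  val booster = runSparkJob(trainUpTo = checkpointRound) // hypothetical helper
  checkpointManager.updateCheckpoint(booster)
}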


@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
+import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.FileSystem
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
 private[this] case class XGBoostExecutionParams(
     numWorkers: Int,
-    round: Int,
+    numRounds: Int,
     useExternalMemory: Boolean,
     obj: ObjectiveTrait,
     eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
     allowNonZeroForMissing: Boolean,
     trackerConf: TrackerConf,
     timeoutRequestWorkers: Long,
-    checkpointParam: CheckpointParam,
+    checkpointParam: Option[ExternalCheckpointParams],
     xgbInputParams: XGBoostExecutionInputParams,
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       .getOrElse("allow_non_zero_for_missing", false)
       .asInstanceOf[Boolean]
-    validateSparkSslConf

     if (overridedParams.contains("tree_method")) {
       require(overridedParams("tree_method") == "hist" ||
         overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
           " an instance of Long.")
     }
     val checkpointParam =
-      CheckpointManager.extractParams(overridedParams)
+      ExternalCheckpointParams.extractParams(overridedParams)
     val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
       .asInstanceOf[Double]
@@ -339,11 +339,9 @@
       watches: Watches,
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
       prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@
     rabitEnv.put("DMLC_TASK_ID", taskId)
     rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
     rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
+    val numRounds = xgbExecutionParam.numRounds
+    val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
     try {
       Rabit.init(rabitEnv)
       val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
-      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
-      val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
-        watches.toMap, metrics, obj, eval,
-        earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
+      val externalCheckpointParams = xgbExecutionParam.checkpointParam
+      val booster = if (makeCheckpoint) {
+        SXGBoost.trainAndSaveCheckpoint(
+          watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+          watches.toMap, metrics, obj, eval,
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
+      } else {
+        SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+          watches.toMap, metrics, obj, eval,
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+      }
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } catch {
       case xgbException: XGBoostError =>
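The heart of the fix sits in the hunk above: instead of the driver splitting training into one Spark job per checkpoint interval, a single job runs to completion and checkpointing happens inside the workers' training loop. Because Rabit keeps the boosters on all workers in sync, only one copy needs to be persisted, so the write is gated on the worker whose DMLC task id is 0 while the other workers train exactly as before. A condensed sketch of that gating (doTrain is a hypothetical stand-in for the two SXGBoost calls above):

// Every partition runs this; only task 0 passes checkpoint parameters
// through, so exactly one worker writes to the checkpoint directory.
def trainPartition(taskId: String, cp: Option[ExternalCheckpointParams]): Booster = {
  val makeCheckpoint = cp.isDefined && taskId.toInt == 0
  if (makeCheckpoint) doTrain(checkpoint = cp)  // persists models as it trains
  else doTrain(checkpoint = None)               // same training, no checkpoint I/O
}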
@@ -437,7 +444,6 @@
       trainingData: RDD[XGBLabeledPoint],
       xgbExecutionParams: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@
         processMissingValues(labeledPoints, xgbExecutionParams.missing,
           xgbExecutionParams.allowNonZeroForMissing),
         getCacheDirName(xgbExecutionParams.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+          xgbExecutionParams.eval, prevBooster)
       }).cache()
     } else {
       coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@
             xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParams.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+          xgbExecutionParams.eval, prevBooster)
       }.cache()
     }
   }
@@ -469,7 +475,6 @@
       trainingData: RDD[Array[XGBLabeledPoint]],
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@
         processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
           xgbExecutionParam.allowNonZeroForMissing),
         getCacheDirName(xgbExecutionParam.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
           xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
       }).cache()
     } else {
@@ -490,7 +495,7 @@
             xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParam.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
           xgbExecutionParam.obj,
           xgbExecutionParam.eval,
           prevBooster)
@@ -529,33 +534,30 @@
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
     val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
+    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
     val sc = trainingData.sparkContext
-    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
-      checkpointPath)
-    checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
     val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
       hasGroup, xgbExecParams.numWorkers)
-    var prevBooster = checkpointManager.loadCheckpointAsBooster
+    val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+      val checkpointManager = new ExternalCheckpointManager(
+        checkpointParam.checkpointPath,
+        FileSystem.get(sc.hadoopConfiguration))
+      checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+      checkpointManager.loadCheckpointAsScalaBooster()
+    }.orNull
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      val producedBooster = checkpointManager.getCheckpointRounds(
-        xgbExecParams.checkpointParam.checkpointInterval,
-        xgbExecParams.round).map {
-        checkpointRound: Int =>
       val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-      try {
+      val (booster, metrics) = try {
         val parallelismTracker = new SparkParallelismTracker(sc,
           xgbExecParams.timeoutRequestWorkers,
           xgbExecParams.numWorkers)
+        val rabitEnv = tracker.getWorkerEnvs
+        tracker.getWorkerEnvs().putAll(xgbRabitParams)
         val boostersAndMetrics = if (hasGroup) {
-          trainForRanking(transformedTrainingData.left.get, xgbExecParams,
-            tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
+          trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
+            evalSetsMap)
         } else {
-          trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
-            tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
+          trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
            prevBooster, evalSetsMap)
         }
         val sparkJobThread = new Thread() {
           override def run() {
@@ -569,20 +571,21 @@
         logger.info(s"Rabit returns with exit code $trackerReturnVal")
         val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
           boostersAndMetrics, sparkJobThread)
-        if (checkpointRound < xgbExecParams.round) {
-          prevBooster = booster
-          checkpointManager.updateCheckpoint(prevBooster)
-        }
         (booster, metrics)
       } finally {
         tracker.stop()
       }
-      }.last
       // we should delete the checkpoint directory after a successful training
-      if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
-        checkpointManager.cleanPath()
-      }
-      producedBooster
+      xgbExecParams.checkpointParam.foreach {
+        cpParam =>
+          if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+            val checkpointManager = new ExternalCheckpointManager(
+              cpParam.checkpointPath,
+              FileSystem.get(sc.hadoopConfiguration))
+            checkpointManager.cleanPath()
+          }
+      }
+      (booster, metrics)
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception


@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
   with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {

   def needDeterministicRepartitioning: Boolean = {
-    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+    getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
   }
 }


@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import org.scalatest.FunSuite
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
+import org.scalatest.{FunSuite, Ignore}
 import org.apache.hadoop.fs.{FileSystem, Path}

-class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {

-  private lazy val (model4, model8) = {
+  private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
+      Map[String, Any] = {
+    Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
+      "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
+  }
+
+  private def createNewModels():
+      (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
+    val (model4, model8) = {
       val training = buildDataFrame(Classification.train)
-      val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-        "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
+      val paramMap = produceParamMap(tmpPath, 2)
       (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
         new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
     }
+    (tmpPath, model4, model8)
+  }

   test("test update/load models") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateCheckpoint(model4._booster)
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+    manager.updateCheckpoint(model4._booster.booster)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "4.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)

     manager.updateCheckpoint(model8._booster)
     files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "8.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
   }

   test("test cleanUpHigherVersions") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
     manager.updateCheckpoint(model8._booster)
-    manager.cleanUpHigherVersions(round = 8)
+    manager.cleanUpHigherVersions(8)
     assert(new File(s"$tmpPath/8.model").exists())

-    manager.cleanUpHigherVersions(round = 4)
+    manager.cleanUpHigherVersions(4)
     assert(!new File(s"$tmpPath/8.model").exists())
   }

   test("test checkpoint rounds") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
-    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+    import scala.collection.JavaConverters._
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+    assertResult(Seq(7))(
+      manager.getCheckpointRounds(0, 7).asScala)
+    assertResult(Seq(2, 4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
     manager.updateCheckpoint(model4._booster)
-    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
+    assertResult(Seq(4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
   }

@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest
     val testDM = new DMatrix(Classification.test.iterator)
     val tmpPath = createTmpFolder("model1").toAbsolutePath.toString

+    val paramMap = produceParamMap(tmpPath, 2)
+
     val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
     val skipCleanCheckpointMap =
       if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
-      skipCleanCheckpointMap

-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
+    val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
+
+    def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)

     if (skipCleanCheckpoint) {
       // Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest
       val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
       // Train next model based on prev model
       val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(tmpModel) >= error(prevModel._booster))
       assert(error(prevModel._booster) > error(nextModel._booster))
       assert(error(nextModel._booster) < 0.1)
     } else {


@@ -127,7 +127,6 @@
       " stop the application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(
@@ -155,7 +154,6 @@
       "does not stop application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(


@@ -17,7 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.java.XGBoostError
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}

 import org.apache.spark.ml.param.ParamMap


@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque

 import scala.util.Random

-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
 import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
 import ml.dmlc.xgboost4j.scala.DMatrix
-import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, Ignore}

 class RabitRobustnessSuite extends FunSuite with PerTest {


@@ -13,6 +13,18 @@
   <packaging>jar</packaging>

   <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>


@@ -0,0 +1,117 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private FileSystem fs;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
if (checkpointPath == null || checkpointPath.isEmpty()) {
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
" empty checkpoint path");
}
this.checkpointPath = new Path(checkpointPath);
this.fs = fs;
}
private String getPath(int version) {
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
}
private List<Integer> getExistingVersions() throws IOException {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
.map(fileName -> Integer.valueOf(
fileName.substring(0, fileName.length() - modelSuffix.length())))
.collect(Collectors.toList());
}
}
public void cleanPath() throws IOException {
fs.delete(checkpointPath, true);
}
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
}
}
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
} catch (IOException e) {
logger.error("failed to delete outdated checkpoint at " + path, e);
}
});
}
}
public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
throws IOException {
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}
}
}
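Two design points in this class are worth calling out. updateCheckpoint writes the model to a temporary file and then renames it over the final name, so a crash mid-write can never leave a truncated "latest" checkpoint; only after the rename are the older versions deleted. And the round/version arithmetic matches the Scala side: a file named "<version>.model" encodes version = 2 * completed rounds. A brief usage sketch (the path and FileSystem setup are illustrative, not from this commit):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import ml.dmlc.xgboost4j.java.ExternalCheckpointManager

val fs = FileSystem.get(new Configuration())
val ecm = new ExternalCheckpointManager("/tmp/xgb-ckpt", fs) // illustrative path
// With interval 2, 7 total rounds, and no prior checkpoints this yields
// [2, 4, 6, 7] (cf. the test suite above).
val rounds = ecm.getCheckpointRounds(2, 7)
// Drop checkpoints that belong to rounds at or beyond a shorter re-run:
ecm.cleanUpHigherVersions(7)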


@@ -15,12 +15,16 @@
  */
 package ml.dmlc.xgboost4j.java;

+import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.*;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;

 /**
  * trainer for xgboost
@@ -108,35 +112,34 @@ public class XGBoost {
     return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
   }

-  /**
-   * Train a booster given parameters.
-   *
-   * @param dtrain Data to be trained.
-   * @param params Parameters.
-   * @param round Number of boosting iterations.
-   * @param watches a group of items to be evaluated during training, this allows user to watch
-   *                performance on the validation set.
-   * @param metrics array containing the evaluation metrics for each matrix in watches for each
-   *                iteration
-   * @param earlyStoppingRounds if non-zero, training would be stopped
-   *                            after a specified number of consecutive
-   *                            goes to the unexpected direction in any evaluation metric.
-   * @param obj customized objective
-   * @param eval customized evaluation
-   * @param booster train from scratch if set to null; train from an existing booster if not null.
-   * @return The trained booster.
-   */
-  public static Booster train(
+  private static void saveCheckpoint(
+      Booster booster,
+      int iter,
+      Set<Integer> checkpointIterations,
+      ExternalCheckpointManager ecm) throws XGBoostError {
+    try {
+      if (checkpointIterations.contains(iter)) {
+        ecm.updateCheckpoint(booster);
+      }
+    } catch (Exception e) {
+      logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
+      throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
+    }
+  }
+
+  public static Booster trainAndSaveCheckpoint(
       DMatrix dtrain,
       Map<String, Object> params,
-      int round,
+      int numRounds,
       Map<String, DMatrix> watches,
       float[][] metrics,
       IObjective obj,
       IEvaluation eval,
       int earlyStoppingRounds,
-      Booster booster) throws XGBoostError {
+      Booster booster,
+      int checkpointInterval,
+      String checkpointPath,
+      FileSystem fs) throws XGBoostError, IOException {
     //collect eval matrixs
     String[] evalNames;
     DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public class XGBoost {
     int bestIteration;
     List<String> names = new ArrayList<String>();
     List<DMatrix> mats = new ArrayList<DMatrix>();
+    Set<Integer> checkpointIterations = new HashSet<>();
+    ExternalCheckpointManager ecm = null;
+    if (checkpointPath != null) {
+      ecm = new ExternalCheckpointManager(checkpointPath, fs);
+    }

     for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
       names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public class XGBoost {
       bestScore = Float.MAX_VALUE;
     }
     bestIteration = 0;
-    metrics = metrics == null ? new float[evalNames.length][round] : metrics;
+    metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;

     //collect all data matrixs
     DMatrix[] allMats;
@@ -181,14 +189,19 @@ public class XGBoost {
       booster.setParams(params);
     }

-    //begin to train
-    for (int iter = booster.getVersion() / 2; iter < round; iter++) {
+    if (ecm != null) {
+      checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
+    }
+
+    // begin to train
+    for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
       if (booster.getVersion() % 2 == 0) {
         if (obj != null) {
           booster.update(dtrain, obj);
         } else {
           booster.update(dtrain, iter);
         }
+        saveCheckpoint(booster, iter, checkpointIterations, ecm);
         booster.saveRabitCheckpoint();
       }
@@ -239,6 +252,44 @@ public class XGBoost {
     return booster;
   }

+  /**
+   * Train a booster given parameters.
+   *
+   * @param dtrain Data to be trained.
+   * @param params Parameters.
+   * @param round Number of boosting iterations.
+   * @param watches a group of items to be evaluated during training, this allows user to watch
+   *                performance on the validation set.
+   * @param metrics array containing the evaluation metrics for each matrix in watches for each
+   *                iteration
+   * @param earlyStoppingRounds if non-zero, training would be stopped
+   *                            after a specified number of consecutive
+   *                            goes to the unexpected direction in any evaluation metric.
+   * @param obj customized objective
+   * @param eval customized evaluation
+   * @param booster train from scratch if set to null; train from an existing booster if not null.
+   * @return The trained booster.
+   */
+  public static Booster train(
+      DMatrix dtrain,
+      Map<String, Object> params,
+      int round,
+      Map<String, DMatrix> watches,
+      float[][] metrics,
+      IObjective obj,
+      IEvaluation eval,
+      int earlyStoppingRounds,
+      Booster booster) throws XGBoostError {
+    try {
+      return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
+          earlyStoppingRounds, booster,
+          -1, null, null);
+    } catch (IOException e) {
+      logger.error("training failed in xgboost4j", e);
+      throw new XGBoostError("training failed in xgboost4j ", e);
+    }
+  }
+
   private static Integer tryGetIntFromObject(Object o) {
     if (o instanceof Integer) {
       return (int)o;


@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
   public XGBoostError(String message) {
     super(message);
   }
+
+  public XGBoostError(String message, Throwable cause) {
+    super(message, cause);
+  }
 }


@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
import org.apache.hadoop.fs.FileSystem
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
extends JavaECM(checkpointPath, fs) {
def updateCheckpoint(booster: Booster): Unit = {
super.updateCheckpoint(booster.booster)
}
def loadCheckpointAsScalaBooster(): Booster = {
val loadedBooster = super.loadCheckpointAsBooster()
if (loadedBooster == null) {
null
} else {
new Booster(loadedBooster)
}
}
}
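The wrapper keeps the Java contract that loading returns null when no checkpoint exists, so Scala callers typically lift the result into an Option. A minimal sketch (path and configuration are illustrative):

val manager = new ExternalCheckpointManager(
  "/tmp/xgb-ckpt", FileSystem.get(new org.apache.hadoop.conf.Configuration()))
val resumed = Option(manager.loadCheckpointAsScalaBooster()) // None on a fresh run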


@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala

 import java.io.InputStream

-import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
+import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
 import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}

 /**
  * XGBoost Scala Training function.
  */
 object XGBoost {
+  private[scala] def trainAndSaveCheckpoint(
+      dtrain: DMatrix,
+      params: Map[String, Any],
+      numRounds: Int,
+      watches: Map[String, DMatrix] = Map(),
+      metrics: Array[Array[Float]] = null,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      earlyStoppingRound: Int = 0,
+      prevBooster: Booster,
+      checkpointParams: Option[ExternalCheckpointParams]): Booster = {
+    val jWatches = watches.mapValues(_.jDMatrix).asJava
+    val jBooster = if (prevBooster == null) {
+      null
+    } else {
+      prevBooster.booster
+    }
+    val xgboostInJava = checkpointParams.
+      map(cp => {
+        JXGBoost.trainAndSaveCheckpoint(
+          dtrain.jDMatrix,
+          // we have to filter null value for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
+          cp.checkpointInterval,
+          cp.checkpointPath,
+          new Path(cp.checkpointPath).getFileSystem(new Configuration()))
+      }).
+      getOrElse(
+        JXGBoost.train(
+          dtrain.jDMatrix,
+          // we have to filter null value for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
+      )
+    if (prevBooster == null) {
+      new Booster(xgboostInJava)
+    } else {
+      // Avoid creating a new SBooster with the same JBooster
+      prevBooster
+    }
+  }
+
   /**
    * Train a booster given parameters.
    *
@@ -55,23 +101,8 @@
       eval: EvalTrait = null,
       earlyStoppingRound: Int = 0,
       booster: Booster = null): Booster = {
-    val jWatches = watches.mapValues(_.jDMatrix).asJava
-    val jBooster = if (booster == null) {
-      null
-    } else {
-      booster.booster
-    }
-    val xgboostInJava = JXGBoost.train(
-      dtrain.jDMatrix,
-      // we have to filter null value for customized obj and eval
-      params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
-      round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
-    if (booster == null) {
-      new Booster(xgboostInJava)
-    } else {
-      // Avoid creating a new SBooster with the same JBooster
-      booster
-    }
+    trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
+      booster, None)
   }
   /**
@@ -126,3 +157,41 @@
     new Booster(xgboostInJava)
   }
 }
+
+private[scala] case class ExternalCheckpointParams(
+    checkpointInterval: Int,
+    checkpointPath: String,
+    skipCleanCheckpoint: Boolean)
+
+private[scala] object ExternalCheckpointParams {
+
+  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
+    val checkpointPath: String = params.get("checkpoint_path") match {
+      case None | Some(null) | Some("") => null
+      case Some(path: String) => path
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
+        s" an instance of String, but current value is ${params("checkpoint_path")}")
+    }
+
+    val checkpointInterval: Int = params.get("checkpoint_interval") match {
+      case None => 0
+      case Some(freq: Int) => freq
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
+        " an instance of Int.")
+    }
+
+    val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean")
+    }
+
+    if (checkpointPath == null || checkpointInterval == 0) {
+      None
+    } else {
+      Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
+    }
+  }
+}
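Under this contract, external checkpointing is enabled only when both settings are supplied, which is what lets the call sites above treat the feature as a plain Option. For example (hypothetical parameter maps):

// Both keys present and non-trivial => Some(ExternalCheckpointParams(2, "/tmp/ckpt", false))
ExternalCheckpointParams.extractParams(
  Map("checkpoint_path" -> "/tmp/ckpt", "checkpoint_interval" -> 2))

// Interval missing (defaults to 0) => None, external checkpointing disabled
ExternalCheckpointParams.extractParams(Map("checkpoint_path" -> "/tmp/ckpt"))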