[jvm-packages] do not use multiple jobs to make checkpoints (#5082)
* temp
* temp
* tep
* address the comments
* fix stylistic issues
* fix
* external checkpoint
This commit is contained in:
parent fa26313feb
commit d7b45fbcaf
@@ -37,6 +37,7 @@
   <spark.version>2.4.3</spark.version>
   <scala.version>2.12.8</scala.version>
   <scala.binary.version>2.12</scala.binary.version>
+  <hadoop.version>2.7.3</hadoop.version>
 </properties>
 <repositories>
   <repository>
@@ -1,164 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.Booster
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
-
-/**
- * A class which allows the user to save checkpoints every few rounds. If a previous job fails,
- * the job can restart training from a saved checkpoint instead of from scratch. This class
- * provides the interface and helper methods for the checkpoint functionality.
- *
- * NOTE: This checkpoint is different from the Rabit checkpoint. The Rabit checkpoint is a
- * native-level checkpoint stored in executor memory; this is a checkpoint which the Spark
- * driver stores on HDFS every few iterations.
- *
- * @param sc the SparkContext object
- * @param checkpointPath the HDFS path to store checkpoints
- */
-private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
-
-  private val logger = LogFactory.getLog("XGBoostSpark")
-
-  private val modelSuffix = ".model"
-
-  private def getPath(version: Int) = {
-    s"$checkpointPath/$version$modelSuffix"
-  }
-
-  private def getExistingVersions: Seq[Int] = {
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
-      Seq()
-    } else {
-      fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
-        case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
-      }
-    }
-  }
-
-  def cleanPath(): Unit = {
-    if (checkpointPath != "") {
-      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
-    }
-  }
-
-  /**
-   * Load the existing checkpoint with the highest version as a Booster object.
-   *
-   * @return the booster with the highest version, or null if no checkpoint is available.
-   */
-  private[spark] def loadCheckpointAsBooster: Booster = {
-    val versions = getExistingVersions
-    if (versions.nonEmpty) {
-      val version = versions.max
-      val fullPath = getPath(version)
-      val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
-      logger.info(s"Start training from previous booster at $fullPath")
-      val booster = SXGBoost.loadModel(inputStream)
-      booster.booster.setVersion(version)
-      booster
-    } else {
-      null
-    }
-  }
-
-  /**
-   * Clean up all previous checkpoints and save a new checkpoint.
-   *
-   * @param checkpoint the checkpoint to save as an XGBoostModel
-   */
-  private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
-    val fullPath = getPath(checkpoint.getVersion)
-    val outputStream = fs.create(new Path(fullPath), true)
-    logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
-    checkpoint.saveModel(outputStream)
-    prevModelPaths.foreach(path => fs.delete(path, true))
-  }
-
-  /**
-   * Clean up checkpoint boosters with a version higher than or equal to the given round.
-   *
-   * @param round the number of rounds in the current training job
-   */
-  private[spark] def cleanUpHigherVersions(round: Int): Unit = {
-    val higherVersions = getExistingVersions.filter(_ / 2 >= round)
-    higherVersions.foreach { version =>
-      val fs = FileSystem.get(sc.hadoopConfiguration)
-      fs.delete(new Path(getPath(version)), true)
-    }
-  }
-
-  /**
-   * Calculate the rounds at which to save checkpoints, based on checkpointInterval and the
-   * total number of rounds for the training. Concretely, the checkpoint rounds start at
-   * prevRounds + checkpointInterval and increase by checkpointInterval at each step until
-   * the total number of rounds is reached. If checkpointInterval is 0, checkpointing is
-   * disabled and the method returns Seq(round).
-   *
-   * @param checkpointInterval period (in iterations) between checkpoints
-   * @param round the total number of rounds for the training
-   * @return a seq of integers, each representing a round at which to save a checkpoint
-   */
-  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
-    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
-      val prevRounds = getExistingVersions.map(_ / 2)
-      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
-      (firstCheckpointRound until round by checkpointInterval) :+ round
-    } else if (checkpointInterval <= 0) {
-      Seq(round)
-    } else {
-      throw new IllegalArgumentException("parameter \"checkpoint_path\" should also be set.")
-    }
-  }
-}
-
-object CheckpointManager {
-
-  case class CheckpointParam(
-      checkpointPath: String,
-      checkpointInterval: Int,
-      skipCleanCheckpoint: Boolean)
-
-  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
-    val checkpointPath: String = params.get("checkpoint_path") match {
-      case None => ""
-      case Some(path: String) => path
-      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
-        " an instance of String.")
-    }
-
-    val checkpointInterval: Int = params.get("checkpoint_interval") match {
-      case None => 0
-      case Some(freq: Int) => freq
-      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
-        " an instance of Int.")
-    }
-
-    val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
-      case None => false
-      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
-      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
-        " an instance of Boolean")
-    }
-    CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
-  }
-}
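One convention from this deleted manager survives unchanged in its replacement: checkpoint files are named after the booster's internal version, not the round number. Each boosting round advances the version by two (the update step and the evaluation/Rabit-sync step each bump it once, as the Java training loop later in this diff suggests), so version / 2 recovers the number of completed rounds. That is why cleanUpHigherVersions filters with `_ / 2 >= round`, and why a checkpoint taken after round 2 is stored as "4.model". A minimal sketch of that arithmetic, assuming nothing beyond the code above:

// Sketch: version <-> round bookkeeping shared by the old and new managers.
// One boosting round bumps the internal version by 2 (update + Rabit sync).
def versionAfter(rounds: Int): Int = rounds * 2   // round 2 -> written as "4.model"
def roundsDone(version: Int): Int = version / 2   // "8.model" -> 4 completed rounds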
@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
+import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.FileSystem

 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)

 private[this] case class XGBoostExecutionParams(
     numWorkers: Int,
-    round: Int,
+    numRounds: Int,
     useExternalMemory: Boolean,
     obj: ObjectiveTrait,
     eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
     allowNonZeroForMissing: Boolean,
     trackerConf: TrackerConf,
     timeoutRequestWorkers: Long,
-    checkpointParam: CheckpointParam,
+    checkpointParam: Option[ExternalCheckpointParams],
     xgbInputParams: XGBoostExecutionInputParams,
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext)
       .getOrElse("allow_non_zero_for_missing", false)
       .asInstanceOf[Boolean]
     validateSparkSslConf
-
     if (overridedParams.contains("tree_method")) {
       require(overridedParams("tree_method") == "hist" ||
         overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext)
         " an instance of Long.")
     }
     val checkpointParam =
-      CheckpointManager.extractParams(overridedParams)
+      ExternalCheckpointParams.extractParams(overridedParams)

     val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
       .asInstanceOf[Double]
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
       watches: Watches,
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
       prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
-
     // to work around empty partitions in the training dataset;
     // this might not be the most efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
     rabitEnv.put("DMLC_TASK_ID", taskId)
     rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
     rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")

+    val numRounds = xgbExecutionParam.numRounds
+    val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
     try {
       Rabit.init(rabitEnv)
       val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
-      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
-      val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
-        watches.toMap, metrics, obj, eval,
-        earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
+      val externalCheckpointParams = xgbExecutionParam.checkpointParam
+      val booster = if (makeCheckpoint) {
+        SXGBoost.trainAndSaveCheckpoint(
+          watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+          watches.toMap, metrics, obj, eval,
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
+      } else {
+        SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+          watches.toMap, metrics, obj, eval,
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+      }
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } catch {
       case xgbException: XGBoostError =>
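This hunk is the worker-side half of the change: every task now runs the full numRounds inside a single Spark job, and only the task whose DMLC_TASK_ID is 0 also persists checkpoints. Since Rabit's allreduce keeps all workers' boosters identical after every round, a single writer loses nothing and avoids any race on the checkpoint directory. Condensed, the gating reduces to the sketch below (names taken from the hunk; taskId is the DMLC task id put into rabitEnv a few lines above):

// Condensed sketch of the single-writer gating above.
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
if (makeCheckpoint) {
  // task 0: train all rounds and write a checkpoint at each checkpoint round
} else {
  // every other task: train the same rounds, write nothing
}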
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
       trainingData: RDD[XGBLabeledPoint],
       xgbExecutionParams: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
         processMissingValues(labeledPoints, xgbExecutionParams.missing,
           xgbExecutionParams.allowNonZeroForMissing),
         getCacheDirName(xgbExecutionParams.useExternalMemory))
-      buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-        xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+      buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+        xgbExecutionParams.eval, prevBooster)
     }).cache()
   } else {
     coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
           xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParams.useExternalMemory))
-      buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-        xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+      buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+        xgbExecutionParams.eval, prevBooster)
     }.cache()
   }
 }
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
       trainingData: RDD[Array[XGBLabeledPoint]],
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
         processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
           xgbExecutionParam.allowNonZeroForMissing),
         getCacheDirName(xgbExecutionParam.useExternalMemory))
-      buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+      buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
         xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
     }).cache()
   } else {
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
           xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParam.useExternalMemory))
-      buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+      buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
         xgbExecutionParam.obj,
         xgbExecutionParam.eval,
         prevBooster)
@@ -529,60 +534,58 @@ object XGBoost extends Serializable {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
     val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
     val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
     val sc = trainingData.sparkContext
-    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
-      checkpointPath)
-    checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
     val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
       hasGroup, xgbExecParams.numWorkers)
-    var prevBooster = checkpointManager.loadCheckpointAsBooster
+    val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+      val checkpointManager = new ExternalCheckpointManager(
+        checkpointParam.checkpointPath,
+        FileSystem.get(sc.hadoopConfiguration))
+      checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+      checkpointManager.loadCheckpointAsScalaBooster()
+    }.orNull
     try {
-      // Train for every ${savingRound} rounds and save the partially completed booster
-      val producedBooster = checkpointManager.getCheckpointRounds(
-        xgbExecParams.checkpointParam.checkpointInterval,
-        xgbExecParams.round).map {
-        checkpointRound: Int =>
-          val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-          try {
-            val parallelismTracker = new SparkParallelismTracker(sc,
-              xgbExecParams.timeoutRequestWorkers,
-              xgbExecParams.numWorkers)
-
-            tracker.getWorkerEnvs().putAll(xgbRabitParams)
-            val boostersAndMetrics = if (hasGroup) {
-              trainForRanking(transformedTrainingData.left.get, xgbExecParams,
-                tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
-            } else {
-              trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
-                tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
-            }
-            val sparkJobThread = new Thread() {
-              override def run() {
-                // force the job
-                boostersAndMetrics.foreachPartition(() => _)
-              }
-            }
-            sparkJobThread.setUncaughtExceptionHandler(tracker)
-            sparkJobThread.start()
-            val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
-            logger.info(s"Rabit returns with exit code $trackerReturnVal")
-            val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
-              boostersAndMetrics, sparkJobThread)
-            if (checkpointRound < xgbExecParams.round) {
-              prevBooster = booster
-              checkpointManager.updateCheckpoint(prevBooster)
-            }
-            (booster, metrics)
-          } finally {
-            tracker.stop()
-          }
-      }.last
-      // we should delete the checkpoint directory after a successful training
-      if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
-        checkpointManager.cleanPath()
-      }
-      producedBooster
+      val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
+      val (booster, metrics) = try {
+        val parallelismTracker = new SparkParallelismTracker(sc,
+          xgbExecParams.timeoutRequestWorkers,
+          xgbExecParams.numWorkers)
+        val rabitEnv = tracker.getWorkerEnvs
+        val boostersAndMetrics = if (hasGroup) {
+          trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
+            evalSetsMap)
+        } else {
+          trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
+            prevBooster, evalSetsMap)
+        }
+        val sparkJobThread = new Thread() {
+          override def run() {
+            // force the job
+            boostersAndMetrics.foreachPartition(() => _)
+          }
+        }
+        sparkJobThread.setUncaughtExceptionHandler(tracker)
+        sparkJobThread.start()
+        val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+        logger.info(s"Rabit returns with exit code $trackerReturnVal")
+        val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
+          boostersAndMetrics, sparkJobThread)
+        (booster, metrics)
+      } finally {
+        tracker.stop()
+      }
+      // we should delete the checkpoint directory after a successful training
+      xgbExecParams.checkpointParam.foreach {
+        cpParam =>
+          if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+            val checkpointManager = new ExternalCheckpointManager(
+              cpParam.checkpointPath,
+              FileSystem.get(sc.hadoopConfiguration))
+            checkpointManager.cleanPath()
+          }
+      }
+      (booster, metrics)
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception
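This hunk is the driver-side half and the reason for the PR title: the old code mapped over getCheckpointRounds, spinning up a new Rabit tracker and a new Spark job for every checkpoint interval and updating the checkpoint between jobs. The new code starts one tracker and one job for all rounds, touching the checkpoint directory only before training (to resume) and after it (to clean up). Stripped of error handling, the new driver flow reduces to this sketch, using only names from the hunk:

// Sketch of the new single-job driver flow (names from the hunk above).
val prevBooster = xgbExecParams.checkpointParam.map { cp =>
  val ecm = new ExternalCheckpointManager(cp.checkpointPath,
    FileSystem.get(sc.hadoopConfiguration))
  ecm.cleanUpHigherVersions(xgbExecParams.numRounds) // drop checkpoints beyond this run
  ecm.loadCheckpointAsScalaBooster()                 // resume point, or null
}.orNull
// ... start one tracker, run one Spark job over all remaining rounds ...
// on success, cleanPath() removes the directory unless skip_clean_checkpoint is set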
@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
     with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {

   def needDeterministicRepartitioning: Boolean = {
-    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+    getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
   }
 }
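The added null guard matters because && short-circuits left to right: getCheckpointPath can be null when the user never set it, and calling nonEmpty on null would throw before the interval check ran. For context, a job that actually trips both predicates is configured roughly as below (path and values illustrative; the parameter keys match those read by ExternalCheckpointParams.extractParams later in this diff):

// Hypothetical configuration enabling external checkpoints (values illustrative).
val classifier = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 100,
  "num_workers" -> 4,
  "checkpoint_path" -> "hdfs:///tmp/xgb-ckpt", // assumed path
  "checkpoint_interval" -> 10))                // checkpoint every 10 rounds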
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import org.scalatest.FunSuite
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
+import org.scalatest.{FunSuite, Ignore}
 import org.apache.hadoop.fs.{FileSystem, Path}

-class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {

-  private lazy val (model4, model8) = {
-    val training = buildDataFrame(Classification.train)
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-      "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
-    (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
-      new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
+  private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
+      Map[String, Any] = {
+    Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
+      "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
+  }
+
+  private def createNewModels():
+      (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
+    val (model4, model8) = {
+      val training = buildDataFrame(Classification.train)
+      val paramMap = produceParamMap(tmpPath, 2)
+      (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
+        new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
+    }
+    (tmpPath, model4, model8)
   }

   test("test update/load models") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateCheckpoint(model4._booster)
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+
+    manager.updateCheckpoint(model4._booster.booster)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "4.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)

     manager.updateCheckpoint(model8._booster)
     files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "8.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
   }

   test("test cleanUpHigherVersions") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
+    val (tmpPath, model4, model8) = createNewModels()
+
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
     manager.updateCheckpoint(model8._booster)
-    manager.cleanUpHigherVersions(round = 8)
+    manager.cleanUpHigherVersions(8)
     assert(new File(s"$tmpPath/8.model").exists())

-    manager.cleanUpHigherVersions(round = 4)
+    manager.cleanUpHigherVersions(4)
     assert(!new File(s"$tmpPath/8.model").exists())
   }

   test("test checkpoint rounds") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
-    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+    import scala.collection.JavaConverters._
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+    assertResult(Seq(7))(
+      manager.getCheckpointRounds(0, 7).asScala)
+    assertResult(Seq(2, 4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
     manager.updateCheckpoint(model4._booster)
-    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
+    assertResult(Seq(4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
   }
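The expected sequences follow directly from the firstCheckpointRound arithmetic in getCheckpointRounds: the next checkpoint round is (rounds already checkpointed) + interval, stepping by the interval, with the final round always appended. Worked through for the assertions above:

// getCheckpointRounds(checkpointInterval = 2, numOfRounds = 7), worked out:
//   empty directory:   max(0) + 2 = 2       -> 2, 4, 6, append 7 => [2, 4, 6, 7]
//   "4.model" on disk: 4 / 2 = 2 rounds done, 2 + 2 = 4 -> 4, 6, append 7 => [4, 6, 7]
//   interval 0: periodic checkpointing disabled                  => [7]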
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     val testDM = new DMatrix(Classification.test.iterator)

     val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+
+    val paramMap = produceParamMap(tmpPath, 2)
+
     val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
     val skipCleanCheckpointMap =
       if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
-      skipCleanCheckpointMap
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
+    val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
+
+    def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)

     if (skipCleanCheckpoint) {
       // Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
       val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
       // Train next model based on prev model
       val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(tmpModel) >= error(prevModel._booster))
       assert(error(prevModel._booster) > error(nextModel._booster))
       assert(error(nextModel._booster) < 0.1)
     } else {
@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     " stop the application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     "does not stop application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(
@@ -17,7 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.java.XGBoostError
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}

 import org.apache.spark.ml.param.ParamMap

@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque

 import scala.util.Random

-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
 import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
 import ml.dmlc.xgboost4j.scala.DMatrix

 import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.FunSuite
-
+import org.scalatest.{FunSuite, Ignore}

 class RabitRobustnessSuite extends FunSuite with PerTest {

@@ -13,6 +13,18 @@
   <packaging>jar</packaging>

   <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
@@ -0,0 +1,117 @@
+package ml.dmlc.xgboost4j.java;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+public class ExternalCheckpointManager {
+
+  private Log logger = LogFactory.getLog("ExternalCheckpointManager");
+  private String modelSuffix = ".model";
+  private Path checkpointPath;
+  private FileSystem fs;
+
+  public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
+    if (checkpointPath == null || checkpointPath.isEmpty()) {
+      throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
+              " empty checkpoint path");
+    }
+    this.checkpointPath = new Path(checkpointPath);
+    this.fs = fs;
+  }
+
+  private String getPath(int version) {
+    return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
+  }
+
+  private List<Integer> getExistingVersions() throws IOException {
+    if (!fs.exists(checkpointPath)) {
+      return new ArrayList<>();
+    } else {
+      return Arrays.stream(fs.listStatus(checkpointPath))
+              .map(path -> path.getPath().getName())
+              .filter(fileName -> fileName.endsWith(modelSuffix))
+              .map(fileName -> Integer.valueOf(
+                      fileName.substring(0, fileName.length() - modelSuffix.length())))
+              .collect(Collectors.toList());
+    }
+  }
+
+  public void cleanPath() throws IOException {
+    fs.delete(checkpointPath, true);
+  }
+
+  public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
+    List<Integer> versions = getExistingVersions();
+    if (versions.size() > 0) {
+      int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
+      String checkpointPath = getPath(latestVersion);
+      InputStream in = fs.open(new Path(checkpointPath));
+      logger.info("loaded checkpoint from " + checkpointPath);
+      Booster booster = XGBoost.loadModel(in);
+      booster.setVersion(latestVersion);
+      return booster;
+    } else {
+      return null;
+    }
+  }
+
+  public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
+    List<String> prevModelPaths = getExistingVersions().stream()
+            .map(this::getPath).collect(Collectors.toList());
+    String eventualPath = getPath(boosterToCheckpoint.getVersion());
+    String tempPath = eventualPath + "-" + UUID.randomUUID();
+    try (OutputStream out = fs.create(new Path(tempPath), true)) {
+      boosterToCheckpoint.saveModel(out);
+      fs.rename(new Path(tempPath), new Path(eventualPath));
+      logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
+      prevModelPaths.stream().forEach(path -> {
+        try {
+          fs.delete(new Path(path), true);
+        } catch (IOException e) {
+          logger.error("failed to delete outdated checkpoint at " + path, e);
+        }
+      });
+    }
+  }
+
+  public void cleanUpHigherVersions(int currentRound) throws IOException {
+    getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
+      try {
+        fs.delete(new Path(getPath(v)), true);
+      } catch (IOException e) {
+        logger.error("failed to clean checkpoint from other training instance", e);
+      }
+    });
+  }
+
+  public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
+          throws IOException {
+    if (checkpointInterval > 0) {
+      List<Integer> prevRounds =
+              getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
+      prevRounds.add(0);
+      int firstCheckpointRound = prevRounds.stream()
+              .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
+      List<Integer> arr = new ArrayList<>();
+      for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
+        arr.add(i);
+      }
+      arr.add(numOfRounds);
+      return arr;
+    } else if (checkpointInterval <= 0) {
+      List<Integer> l = new ArrayList<Integer>();
+      l.add(numOfRounds);
+      return l;
+    } else {
+      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
+    }
+  }
+}
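Two details of the new class are worth noting. First, updateCheckpoint writes to a UUID-suffixed temporary file and only then renames it over the final name, so a concurrent reader can never observe a half-written model. Second, the class now lives in the core xgboost4j package (hence the provided-scope Hadoop dependencies added to the pom above) and needs only a Hadoop FileSystem, not a SparkContext. A hypothetical driver-side use via the Scala wrapper that appears later in this diff (path illustrative, `sc` assumed to be a SparkContext):

import org.apache.hadoop.fs.FileSystem
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager

val fs  = FileSystem.get(sc.hadoopConfiguration)
val ecm = new ExternalCheckpointManager("hdfs:///tmp/xgb-ckpt", fs) // assumed path
Option(ecm.loadCheckpointAsScalaBooster()) match {
  case Some(b) => println(s"resuming from booster version ${b.getVersion}")
  case None    => println("no checkpoint found, training from scratch")
}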
@@ -15,12 +15,16 @@
  */
 package ml.dmlc.xgboost4j.java;

 import java.io.File;
+import java.io.IOException;
 import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.*;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;

 /**
  * trainer for xgboost
@@ -108,35 +112,34 @@ public class XGBoost {
     return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
   }

-  /**
-   * Train a booster given parameters.
-   *
-   * @param dtrain  Data to be trained.
-   * @param params  Parameters.
-   * @param round   Number of boosting iterations.
-   * @param watches a group of items to be evaluated during training, this allows user to watch
-   *                performance on the validation set.
-   * @param metrics array containing the evaluation metrics for each matrix in watches for each
-   *                iteration
-   * @param earlyStoppingRounds if non-zero, training is stopped after this many consecutive
-   *                            rounds in which the evaluation metric moves in the unexpected
-   *                            direction.
-   * @param obj     customized objective
-   * @param eval    customized evaluation
-   * @param booster train from scratch if set to null; train from an existing booster if not null.
-   * @return The trained booster.
-   */
-  public static Booster train(
-      DMatrix dtrain,
-      Map<String, Object> params,
-      int round,
-      Map<String, DMatrix> watches,
-      float[][] metrics,
-      IObjective obj,
-      IEvaluation eval,
-      int earlyStoppingRounds,
-      Booster booster) throws XGBoostError {
+  private static void saveCheckpoint(
+      Booster booster,
+      int iter,
+      Set<Integer> checkpointIterations,
+      ExternalCheckpointManager ecm) throws XGBoostError {
+    try {
+      if (checkpointIterations.contains(iter)) {
+        ecm.updateCheckpoint(booster);
+      }
+    } catch (Exception e) {
+      logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
+      throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
+    }
+  }
+
+  public static Booster trainAndSaveCheckpoint(
+      DMatrix dtrain,
+      Map<String, Object> params,
+      int numRounds,
+      Map<String, DMatrix> watches,
+      float[][] metrics,
+      IObjective obj,
+      IEvaluation eval,
+      int earlyStoppingRounds,
+      Booster booster,
+      int checkpointInterval,
+      String checkpointPath,
+      FileSystem fs) throws XGBoostError, IOException {
     //collect eval matrices
     String[] evalNames;
     DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public class XGBoost {
     int bestIteration;
     List<String> names = new ArrayList<String>();
     List<DMatrix> mats = new ArrayList<DMatrix>();
+    Set<Integer> checkpointIterations = new HashSet<>();
+    ExternalCheckpointManager ecm = null;
+    if (checkpointPath != null) {
+      ecm = new ExternalCheckpointManager(checkpointPath, fs);
+    }

     for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
       names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public class XGBoost {
       bestScore = Float.MAX_VALUE;
     }
     bestIteration = 0;
-    metrics = metrics == null ? new float[evalNames.length][round] : metrics;
+    metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;

     //collect all data matrices
     DMatrix[] allMats;
@@ -181,14 +189,19 @@ public class XGBoost {
       booster.setParams(params);
     }

-    //begin to train
-    for (int iter = booster.getVersion() / 2; iter < round; iter++) {
+    if (ecm != null) {
+      checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
+    }
+
+    // begin to train
+    for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
       if (booster.getVersion() % 2 == 0) {
         if (obj != null) {
           booster.update(dtrain, obj);
         } else {
           booster.update(dtrain, iter);
         }
+        saveCheckpoint(booster, iter, checkpointIterations, ecm);
         booster.saveRabitCheckpoint();
       }

@@ -224,7 +237,7 @@ public class XGBoost {
         if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
           Rabit.trackerPrint(String.format(
               "early stopping after %d rounds away from the best iteration",
-                  earlyStoppingRounds));
+              earlyStoppingRounds));
           break;
         }
       }
@@ -239,6 +252,44 @@ public class XGBoost {
     return booster;
   }

+  /**
+   * Train a booster given parameters.
+   *
+   * @param dtrain  Data to be trained.
+   * @param params  Parameters.
+   * @param round   Number of boosting iterations.
+   * @param watches a group of items to be evaluated during training, this allows user to watch
+   *                performance on the validation set.
+   * @param metrics array containing the evaluation metrics for each matrix in watches for each
+   *                iteration
+   * @param earlyStoppingRounds if non-zero, training is stopped after this many consecutive
+   *                            rounds in which the evaluation metric moves in the unexpected
+   *                            direction.
+   * @param obj     customized objective
+   * @param eval    customized evaluation
+   * @param booster train from scratch if set to null; train from an existing booster if not null.
+   * @return The trained booster.
+   */
+  public static Booster train(
+      DMatrix dtrain,
+      Map<String, Object> params,
+      int round,
+      Map<String, DMatrix> watches,
+      float[][] metrics,
+      IObjective obj,
+      IEvaluation eval,
+      int earlyStoppingRounds,
+      Booster booster) throws XGBoostError {
+    try {
+      return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
+          earlyStoppingRounds, booster,
+          -1, null, null);
+    } catch (IOException e) {
+      logger.error("training failed in xgboost4j", e);
+      throw new XGBoostError("training failed in xgboost4j", e);
+    }
+  }
+
   private static Integer tryGetIntFromObject(Object o) {
     if (o instanceof Integer) {
       return (int)o;
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
   public XGBoostError(String message) {
     super(message);
   }
+
+  public XGBoostError(String message, Throwable cause) {
+    super(message, cause);
+  }
 }
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
+import org.apache.hadoop.fs.FileSystem
+
+class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
+  extends JavaECM(checkpointPath, fs) {
+
+  def updateCheckpoint(booster: Booster): Unit = {
+    super.updateCheckpoint(booster.booster)
+  }
+
+  def loadCheckpointAsScalaBooster(): Booster = {
+    val loadedBooster = super.loadCheckpointAsBooster()
+    if (loadedBooster == null) {
+      null
+    } else {
+      new Booster(loadedBooster)
+    }
+  }
+}
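The wrapper adds only type adaptation over the Java implementation: it unwraps the Scala Booster for writes and re-wraps (or passes through null) on reads. A short round-trip sketch, assuming a trained Scala `booster`, a Hadoop `fs`, and a `path` are already in hand:

// Sketch: round-tripping a checkpoint through the Scala wrapper.
val ecm = new ExternalCheckpointManager(path, fs)
ecm.updateCheckpoint(booster)                     // writes "<version>.model"
val restored = ecm.loadCheckpointAsScalaBooster() // highest version on disk
assert(restored.getVersion == booster.getVersion)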
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala

 import java.io.InputStream

-import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
+import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
 import scala.collection.JavaConverters._

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
 /**
  * XGBoost Scala Training function.
  */
 object XGBoost {

+  private[scala] def trainAndSaveCheckpoint(
+      dtrain: DMatrix,
+      params: Map[String, Any],
+      numRounds: Int,
+      watches: Map[String, DMatrix] = Map(),
+      metrics: Array[Array[Float]] = null,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      earlyStoppingRound: Int = 0,
+      prevBooster: Booster,
+      checkpointParams: Option[ExternalCheckpointParams]): Booster = {
+    val jWatches = watches.mapValues(_.jDMatrix).asJava
+    val jBooster = if (prevBooster == null) {
+      null
+    } else {
+      prevBooster.booster
+    }
+    val xgboostInJava = checkpointParams.
+      map(cp => {
+        JXGBoost.trainAndSaveCheckpoint(
+          dtrain.jDMatrix,
+          // we have to filter null values for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
+          cp.checkpointInterval,
+          cp.checkpointPath,
+          new Path(cp.checkpointPath).getFileSystem(new Configuration()))
+      }).
+      getOrElse(
+        JXGBoost.train(
+          dtrain.jDMatrix,
+          // we have to filter null values for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
+      )
+    if (prevBooster == null) {
+      new Booster(xgboostInJava)
+    } else {
+      // Avoid creating a new SBooster with the same JBooster
+      prevBooster
+    }
+  }

 /**
  * Train a booster given parameters.
  *
@@ -55,23 +101,8 @@ object XGBoost {
       eval: EvalTrait = null,
       earlyStoppingRound: Int = 0,
       booster: Booster = null): Booster = {
-    val jWatches = watches.mapValues(_.jDMatrix).asJava
-    val jBooster = if (booster == null) {
-      null
-    } else {
-      booster.booster
-    }
-    val xgboostInJava = JXGBoost.train(
-      dtrain.jDMatrix,
-      // we have to filter null values for customized obj and eval
-      params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
-      round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
-    if (booster == null) {
-      new Booster(xgboostInJava)
-    } else {
-      // Avoid creating a new SBooster with the same JBooster
-      booster
-    }
+    trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
+      booster, None)
   }

 /**
@@ -126,3 +157,41 @@ object XGBoost {
     new Booster(xgboostInJava)
   }
 }
+
+private[scala] case class ExternalCheckpointParams(
+    checkpointInterval: Int,
+    checkpointPath: String,
+    skipCleanCheckpoint: Boolean)
+
+private[scala] object ExternalCheckpointParams {
+
+  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
+    val checkpointPath: String = params.get("checkpoint_path") match {
+      case None | Some(null) | Some("") => null
+      case Some(path: String) => path
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
+        s" an instance of String, but current value is ${params("checkpoint_path")}")
+    }
+
+    val checkpointInterval: Int = params.get("checkpoint_interval") match {
+      case None => 0
+      case Some(freq: Int) => freq
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
+        " an instance of Int.")
+    }
+
+    val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean")
+    }
+    if (checkpointPath == null || checkpointInterval == 0) {
+      None
+    } else {
+      Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
+    }
+  }
+}
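extractParams is now the single gate for the feature: it yields None unless both a non-empty path and a non-zero interval are supplied, which is what lets the rest of the code branch with Option.map and foreach instead of empty-string checks. Its behaviour on illustrative inputs:

// ExternalCheckpointParams.extractParams on illustrative inputs:
ExternalCheckpointParams.extractParams(Map(
  "checkpoint_path" -> "hdfs:///tmp/ckpt", "checkpoint_interval" -> 2))
// => Some(ExternalCheckpointParams(2, "hdfs:///tmp/ckpt", false))

ExternalCheckpointParams.extractParams(Map("checkpoint_interval" -> 2))
// => None: without a path, checkpointing stays disabled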