[jvm-packages] do not use multiple jobs to make checkpoints (#5082)
* temp
* temp
* tep
* address the comments
* fix stylistic issues
* fix
* external checkpoint

This commit is contained in:
parent fa26313feb
commit d7b45fbcaf
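In short, the commit deletes the Spark-side CheckpointManager, which drove one Spark job per checkpoint interval, and replaces it with an ExternalCheckpointManager in the core xgboost4j library (plus a thin Scala wrapper), so external checkpoints are written from inside a single long-running training job. The Hadoop filesystem dependencies it needs are added to the xgboost4j pom as provided-scope dependencies.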
@@ -37,6 +37,7 @@
   <spark.version>2.4.3</spark.version>
   <scala.version>2.12.8</scala.version>
   <scala.binary.version>2.12</scala.binary.version>
+  <hadoop.version>2.7.3</hadoop.version>
 </properties>
 <repositories>
   <repository>
@@ -1,164 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.Booster
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
-
-/**
- * A class which allows user to save checkpoints every a few rounds. If a previous job fails,
- * the job can restart training from a saved checkpoints instead of from scratch. This class
- * provides interface and helper methods for the checkpoint functionality.
- *
- * NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level
- * checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS
- * for every a few iterations.
- *
- * @param sc the sparkContext object
- * @param checkpointPath the hdfs path to store checkpoints
- */
-private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
-  private val logger = LogFactory.getLog("XGBoostSpark")
-  private val modelSuffix = ".model"
-
-  private def getPath(version: Int) = {
-    s"$checkpointPath/$version$modelSuffix"
-  }
-
-  private def getExistingVersions: Seq[Int] = {
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
-      Seq()
-    } else {
-      fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
-        case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
-      }
-    }
-  }
-
-  def cleanPath(): Unit = {
-    if (checkpointPath != "") {
-      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
-    }
-  }
-
-  /**
-   * Load existing checkpoint with the highest version as a Booster object
-   *
-   * @return the booster with the highest version, null if no checkpoints available.
-   */
-  private[spark] def loadCheckpointAsBooster: Booster = {
-    val versions = getExistingVersions
-    if (versions.nonEmpty) {
-      val version = versions.max
-      val fullPath = getPath(version)
-      val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
-      logger.info(s"Start training from previous booster at $fullPath")
-      val booster = SXGBoost.loadModel(inputStream)
-      booster.booster.setVersion(version)
-      booster
-    } else {
-      null
-    }
-  }
-
-  /**
-   * Clean up all previous checkpoints and save a new checkpoint
-   *
-   * @param checkpoint the checkpoint to save as an XGBoostModel
-   */
-  private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
-    val fullPath = getPath(checkpoint.getVersion)
-    val outputStream = fs.create(new Path(fullPath), true)
-    logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
-    checkpoint.saveModel(outputStream)
-    prevModelPaths.foreach(path => fs.delete(path, true))
-  }
-
-  /**
-   * Clean up checkpoint boosters with version higher than or equal to the round.
-   *
-   * @param round the number of rounds in the current training job
-   */
-  private[spark] def cleanUpHigherVersions(round: Int): Unit = {
-    val higherVersions = getExistingVersions.filter(_ / 2 >= round)
-    higherVersions.foreach { version =>
-      val fs = FileSystem.get(sc.hadoopConfiguration)
-      fs.delete(new Path(getPath(version)), true)
-    }
-  }
-
-  /**
-   * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
-   * and total number of rounds for the training. Concretely, the checkpoint rounds start with
-   * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
-   * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
-   * and the method returns Seq(round)
-   *
-   * @param checkpointInterval Period (in iterations) between checkpoints.
-   * @param round the total number of rounds for the training
-   * @return a seq of integers, each represent the index of round to save the checkpoints
-   */
-  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
-    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
-      val prevRounds = getExistingVersions.map(_ / 2)
-      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
-      (firstCheckpointRound until round by checkpointInterval) :+ round
-    } else if (checkpointInterval <= 0) {
-      Seq(round)
-    } else {
-      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
-    }
-  }
-}
-
-object CheckpointManager {
-
-  case class CheckpointParam(
-      checkpointPath: String,
-      checkpointInterval: Int,
-      skipCleanCheckpoint: Boolean)
-
-  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
-    val checkpointPath: String = params.get("checkpoint_path") match {
-      case None => ""
-      case Some(path: String) => path
-      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
-        " an instance of String.")
-    }
-
-    val checkpointInterval: Int = params.get("checkpoint_interval") match {
-      case None => 0
-      case Some(freq: Int) => freq
-      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
-        " an instance of Int.")
-    }
-
-    val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
-      case None => false
-      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
-      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
-        " an instance of Boolean")
-    }
-    CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
-  }
-}
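The file deleted above was the heart of the old design: the driver asked getCheckpointRounds for a list of intermediate round counts and then ran one complete Spark training job per entry, saving the booster between jobs. A minimal sketch of that old driver-side loop shape (standalone and simplified; the commented-out runTrainingJob helper is a hypothetical stand-in, not code from this repository):

    // Old multi-job pattern (sketch): interval 2 with 7 total rounds meant
    // separate Spark jobs for rounds 2, 4, 6 and 7, each paying full job
    // startup and data-distribution cost.
    val checkpointRounds = Seq(2, 4, 6, 7)  // what getCheckpointRounds(2, 7) returned
    var prev: Option[Array[Byte]] = None    // serialized booster carried between jobs
    for (round <- checkpointRounds) {
      // prev = Some(runTrainingJob(upTo = round, resumeFrom = prev)) // one full job each
    }

The rest of the diff removes exactly this loop and pushes checkpoint writing down into the training task.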
@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
+import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import org.apache.commons.io.FileUtils
 import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.FileSystem

 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see

 private[this] case class XGBoostExecutionParams(
     numWorkers: Int,
-    round: Int,
+    numRounds: Int,
     useExternalMemory: Boolean,
     obj: ObjectiveTrait,
     eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
     allowNonZeroForMissing: Boolean,
     trackerConf: TrackerConf,
     timeoutRequestWorkers: Long,
-    checkpointParam: CheckpointParam,
+    checkpointParam: Option[ExternalCheckpointParams],
     xgbInputParams: XGBoostExecutionInputParams,
     earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
     cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
       .getOrElse("allow_non_zero_for_missing", false)
       .asInstanceOf[Boolean]
     validateSparkSslConf

     if (overridedParams.contains("tree_method")) {
       require(overridedParams("tree_method") == "hist" ||
         overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
         " an instance of Long.")
     }
     val checkpointParam =
-      CheckpointManager.extractParams(overridedParams)
+      ExternalCheckpointParams.extractParams(overridedParams)

     val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
       .asInstanceOf[Double]
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
       watches: Watches,
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
       prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {

     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
     rabitEnv.put("DMLC_TASK_ID", taskId)
     rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
     rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
+    val numRounds = xgbExecutionParam.numRounds
+    val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
     try {
       Rabit.init(rabitEnv)
       val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
-      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
+      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
-      val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
+      val externalCheckpointParams = xgbExecutionParam.checkpointParam
+      val booster = if (makeCheckpoint) {
+        SXGBoost.trainAndSaveCheckpoint(
+          watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+          watches.toMap, metrics, obj, eval,
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
+      } else {
+        SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
           watches.toMap, metrics, obj, eval,
           earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+      }
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } catch {
       case xgbException: XGBoostError =>
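With the per-round job loop gone, checkpointing moves into the training task itself: makeCheckpoint is true only on the worker whose Rabit task ID is 0, so exactly one executor writes external checkpoints while every worker keeps training in the same job. A standalone sketch of the single-writer pattern (the helper names here are hypothetical, not the PR's code):

    // Single-writer checkpointing inside one long-running job (sketch).
    def boostingLoop(taskId: Int, numRounds: Int, checkpointRounds: Set[Int]): Unit = {
      val isWriter = taskId == 0            // mirrors makeCheckpoint above
      for (round <- 1 to numRounds) {
        // the allreduce-based boosting update runs on every worker here
        if (isWriter && checkpointRounds.contains(round)) {
          // persist the current booster to the checkpoint directory
        }
      }
    }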
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
       trainingData: RDD[XGBLabeledPoint],
       xgbExecutionParams: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
           processMissingValues(labeledPoints, xgbExecutionParams.missing,
             xgbExecutionParams.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParams.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+          xgbExecutionParams.eval, prevBooster)
       }).cache()
     } else {
       coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
             xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParams.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
-          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+          xgbExecutionParams.eval, prevBooster)
       }.cache()
     }
   }
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
       trainingData: RDD[Array[XGBLabeledPoint]],
       xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
-      checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
     if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
           processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
             xgbExecutionParam.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParam.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
           xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
       }).cache()
     } else {
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
             xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
         },
         getCacheDirName(xgbExecutionParam.useExternalMemory))
-        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
           xgbExecutionParam.obj,
           xgbExecutionParam.eval,
           prevBooster)
@@ -529,33 +534,30 @@ object XGBoost extends Serializable {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
     val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
-    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
     val sc = trainingData.sparkContext
-    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
-      checkpointPath)
-    checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
     val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
       hasGroup, xgbExecParams.numWorkers)
-    var prevBooster = checkpointManager.loadCheckpointAsBooster
+    val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+      val checkpointManager = new ExternalCheckpointManager(
+        checkpointParam.checkpointPath,
+        FileSystem.get(sc.hadoopConfiguration))
+      checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+      checkpointManager.loadCheckpointAsScalaBooster()
+    }.orNull
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      val producedBooster = checkpointManager.getCheckpointRounds(
-        xgbExecParams.checkpointParam.checkpointInterval,
-        xgbExecParams.round).map {
-        checkpointRound: Int =>
       val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-      try {
+      val (booster, metrics) = try {
         val parallelismTracker = new SparkParallelismTracker(sc,
           xgbExecParams.timeoutRequestWorkers,
           xgbExecParams.numWorkers)
+        val rabitEnv = tracker.getWorkerEnvs
-        tracker.getWorkerEnvs().putAll(xgbRabitParams)
         val boostersAndMetrics = if (hasGroup) {
-          trainForRanking(transformedTrainingData.left.get, xgbExecParams,
-            tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
+          trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
+            evalSetsMap)
         } else {
-          trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
-            tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
+          trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
+            prevBooster, evalSetsMap)
         }
         val sparkJobThread = new Thread() {
           override def run() {
@@ -569,20 +571,21 @@ object XGBoost extends Serializable {
         logger.info(s"Rabit returns with exit code $trackerReturnVal")
         val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
           boostersAndMetrics, sparkJobThread)
-        if (checkpointRound < xgbExecParams.round) {
-          prevBooster = booster
-          checkpointManager.updateCheckpoint(prevBooster)
-        }
         (booster, metrics)
       } finally {
         tracker.stop()
       }
-      }.last
       // we should delete the checkpoint directory after a successful training
-      if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
+      xgbExecParams.checkpointParam.foreach {
+        cpParam =>
+          if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+            val checkpointManager = new ExternalCheckpointManager(
+              cpParam.checkpointPath,
+              FileSystem.get(sc.hadoopConfiguration))
             checkpointManager.cleanPath()
           }
-      producedBooster
+      }
+      (booster, metrics)
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception
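A detail both the old and the new code lean on: each boosting round advances the booster's Rabit version by two (one bump for the update, one for the Rabit checkpoint step), so checkpoint files are named after the version, not the round. That is why cleanUpHigherVersions and getCheckpointRounds divide versions by 2. A small sketch of the invariant, with values taken from the test suite further down:

    // Version bookkeeping (sketch of the invariant, not library code).
    def versionAfter(rounds: Int): Int = 2 * rounds // each round: update + rabit checkpoint
    def completedRounds(version: Int): Int = version / 2

    versionAfter(2)     // 4 -> the tests expect a "4.model" file after num_round = 2
    completedRounds(8)  // 4 -> resuming from "8.model" restarts at boosting round 4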
@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
     with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {

   def needDeterministicRepartitioning: Boolean = {
-    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+    getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
   }
 }
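The extra != null guard matters because getCheckpointPath can surface as null when the parameter was never set, and calling .nonEmpty on a null String throws a NullPointerException. A one-line sketch of the guarded predicate outside the Params machinery (an assumption-level restatement, not code from the PR):

    // Guarded check (sketch): safe for unset (null), empty, and disabled-interval cases.
    def needRepartition(path: String, interval: Int): Boolean =
      path != null && path.nonEmpty && interval > 0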
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, Ignore}
 import org.apache.hadoop.fs.{FileSystem, Path}

-class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {

-  private lazy val (model4, model8) = {
+  private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
+      Map[String, Any] = {
+    Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
+      "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
+  }
+
+  private def createNewModels():
+      (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
+    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
+    val (model4, model8) = {
       val training = buildDataFrame(Classification.train)
-      val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
-        "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
+      val paramMap = produceParamMap(tmpPath, 2)
       (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
         new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
     }
+    (tmpPath, model4, model8)
+  }

   test("test update/load models") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    manager.updateCheckpoint(model4._booster)
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+    manager.updateCheckpoint(model4._booster.booster)
     var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "4.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)

     manager.updateCheckpoint(model8._booster)
     files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
     assert(files.length == 1)
     assert(files.head.getPath.getName == "8.model")
-    assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
+    assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
   }

   test("test cleanUpHigherVersions") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
     manager.updateCheckpoint(model8._booster)
-    manager.cleanUpHigherVersions(round = 8)
+    manager.cleanUpHigherVersions(8)
     assert(new File(s"$tmpPath/8.model").exists())

-    manager.cleanUpHigherVersions(round = 4)
+    manager.cleanUpHigherVersions(4)
     assert(!new File(s"$tmpPath/8.model").exists())
   }

   test("test checkpoint rounds") {
-    val tmpPath = createTmpFolder("test").toAbsolutePath.toString
-    val manager = new CheckpointManager(sc, tmpPath)
-    assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
-    assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+    import scala.collection.JavaConverters._
+    val (tmpPath, model4, model8) = createNewModels()
+    val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+    assertResult(Seq(7))(
+      manager.getCheckpointRounds(0, 7).asScala)
+    assertResult(Seq(2, 4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
     manager.updateCheckpoint(model4._booster)
-    assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
+    assertResult(Seq(4, 6, 7))(
+      manager.getCheckpointRounds(2, 7).asScala)
   }

@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
     val testDM = new DMatrix(Classification.test.iterator)

     val tmpPath = createTmpFolder("model1").toAbsolutePath.toString

+    val paramMap = produceParamMap(tmpPath, 2)
+
     val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
     val skipCleanCheckpointMap =
       if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
-      skipCleanCheckpointMap

-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
+    val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
+
+    def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)

     if (skipCleanCheckpoint) {
       // Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
       val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
       // Train next model based on prev model
       val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(tmpModel) >= error(prevModel._booster))
       assert(error(prevModel._booster) > error(nextModel._booster))
       assert(error(nextModel._booster) < 0.1)
     } else {
@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     " stop the application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
     "does not stop application") {
     val spark = ss
     import spark.implicits._
-    ss.sparkContext.setLogLevel("INFO")
     // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
     // vector,
     val testDF = Seq(
@@ -17,7 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.java.XGBoostError
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}

 import org.apache.spark.ml.param.ParamMap
@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque

 import scala.util.Random

-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
 import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
 import ml.dmlc.xgboost4j.scala.DMatrix

-import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{FunSuite, Ignore}
-import org.scalatest.FunSuite


 class RabitRobustnessSuite extends FunSuite with PerTest {

@@ -13,6 +13,18 @@
   <packaging>jar</packaging>

   <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
@@ -0,0 +1,117 @@
+package ml.dmlc.xgboost4j.java;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+public class ExternalCheckpointManager {
+
+  private Log logger = LogFactory.getLog("ExternalCheckpointManager");
+  private String modelSuffix = ".model";
+  private Path checkpointPath;
+  private FileSystem fs;
+
+  public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
+    if (checkpointPath == null || checkpointPath.isEmpty()) {
+      throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
+              " empty checkpoint path");
+    }
+    this.checkpointPath = new Path(checkpointPath);
+    this.fs = fs;
+  }
+
+  private String getPath(int version) {
+    return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
+  }
+
+  private List<Integer> getExistingVersions() throws IOException {
+    if (!fs.exists(checkpointPath)) {
+      return new ArrayList<>();
+    } else {
+      return Arrays.stream(fs.listStatus(checkpointPath))
+              .map(path -> path.getPath().getName())
+              .filter(fileName -> fileName.endsWith(modelSuffix))
+              .map(fileName -> Integer.valueOf(
+                      fileName.substring(0, fileName.length() - modelSuffix.length())))
+              .collect(Collectors.toList());
+    }
+  }
+
+  public void cleanPath() throws IOException {
+    fs.delete(checkpointPath, true);
+  }
+
+  public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
+    List<Integer> versions = getExistingVersions();
+    if (versions.size() > 0) {
+      int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
+      String checkpointPath = getPath(latestVersion);
+      InputStream in = fs.open(new Path(checkpointPath));
+      logger.info("loaded checkpoint from " + checkpointPath);
+      Booster booster = XGBoost.loadModel(in);
+      booster.setVersion(latestVersion);
+      return booster;
+    } else {
+      return null;
+    }
+  }
+
+  public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
+    List<String> prevModelPaths = getExistingVersions().stream()
+            .map(this::getPath).collect(Collectors.toList());
+    String eventualPath = getPath(boosterToCheckpoint.getVersion());
+    String tempPath = eventualPath + "-" + UUID.randomUUID();
+    try (OutputStream out = fs.create(new Path(tempPath), true)) {
+      boosterToCheckpoint.saveModel(out);
+      fs.rename(new Path(tempPath), new Path(eventualPath));
+      logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
+      prevModelPaths.stream().forEach(path -> {
+        try {
+          fs.delete(new Path(path), true);
+        } catch (IOException e) {
+          logger.error("failed to delete outdated checkpoint at " + path, e);
+        }
+      });
+    }
+  }
+
+  public void cleanUpHigherVersions(int currentRound) throws IOException {
+    getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
+      try {
+        fs.delete(new Path(getPath(v)), true);
+      } catch (IOException e) {
+        logger.error("failed to clean checkpoint from other training instance", e);
+      }
+    });
+  }
+
+  public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
+          throws IOException {
+    if (checkpointInterval > 0) {
+      List<Integer> prevRounds =
+              getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
+      prevRounds.add(0);
+      int firstCheckpointRound = prevRounds.stream()
+              .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
+      List<Integer> arr = new ArrayList<>();
+      for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
+        arr.add(i);
+      }
+      arr.add(numOfRounds);
+      return arr;
+    } else if (checkpointInterval <= 0) {
+      List<Integer> l = new ArrayList<Integer>();
+      l.add(numOfRounds);
+      return l;
+    } else {
+      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
+    }
+  }
+}
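Two design points in this new class are easy to miss: updateCheckpoint writes to a "<version>.model-<uuid>" temporary file and only then renames it over the final path, so a crash mid-write never leaves a truncated checkpoint behind, and older versions are deleted only after the new one is in place. A hedged usage sketch driving the class from Scala (the path and configuration are placeholders, not values from the PR):

    // Sketch: using the new Java manager from Scala (placeholder path/conf).
    import ml.dmlc.xgboost4j.java.ExternalCheckpointManager
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.FileSystem

    val fs = FileSystem.get(new Configuration())
    val ecm = new ExternalCheckpointManager("/tmp/xgb-checkpoints", fs)
    val resumed = ecm.loadCheckpointAsBooster()  // null when no checkpoint exists yet
    val rounds = ecm.getCheckpointRounds(2, 7)   // [2, 4, 6, 7] on a fresh directory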
@@ -15,12 +15,16 @@
 */
 package ml.dmlc.xgboost4j.java;

+import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.*;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;

 /**
  * trainer for xgboost
@@ -108,35 +112,34 @@ public class XGBoost {
     return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
   }

-  /**
-   * Train a booster given parameters.
-   *
-   * @param dtrain Data to be trained.
-   * @param params Parameters.
-   * @param round Number of boosting iterations.
-   * @param watches a group of items to be evaluated during training, this allows user to watch
-   *                performance on the validation set.
-   * @param metrics array containing the evaluation metrics for each matrix in watches for each
-   *                iteration
-   * @param earlyStoppingRounds if non-zero, training would be stopped
-   *                after a specified number of consecutive
-   *                goes to the unexpected direction in any evaluation metric.
-   * @param obj customized objective
-   * @param eval customized evaluation
-   * @param booster train from scratch if set to null; train from an existing booster if not null.
-   * @return The trained booster.
-   */
-  public static Booster train(
+  private static void saveCheckpoint(
+      Booster booster,
+      int iter,
+      Set<Integer> checkpointIterations,
+      ExternalCheckpointManager ecm) throws XGBoostError {
+    try {
+      if (checkpointIterations.contains(iter)) {
+        ecm.updateCheckpoint(booster);
+      }
+    } catch (Exception e) {
+      logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
+      throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
+    }
+  }
+
+  public static Booster trainAndSaveCheckpoint(
       DMatrix dtrain,
       Map<String, Object> params,
-      int round,
+      int numRounds,
       Map<String, DMatrix> watches,
       float[][] metrics,
       IObjective obj,
       IEvaluation eval,
       int earlyStoppingRounds,
-      Booster booster) throws XGBoostError {
+      Booster booster,
+      int checkpointInterval,
+      String checkpointPath,
+      FileSystem fs) throws XGBoostError, IOException {
     //collect eval matrixs
     String[] evalNames;
     DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public class XGBoost {
     int bestIteration;
     List<String> names = new ArrayList<String>();
     List<DMatrix> mats = new ArrayList<DMatrix>();
+    Set<Integer> checkpointIterations = new HashSet<>();
+    ExternalCheckpointManager ecm = null;
+    if (checkpointPath != null) {
+      ecm = new ExternalCheckpointManager(checkpointPath, fs);
+    }

     for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
       names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public class XGBoost {
       bestScore = Float.MAX_VALUE;
     }
     bestIteration = 0;
-    metrics = metrics == null ? new float[evalNames.length][round] : metrics;
+    metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;

     //collect all data matrixs
     DMatrix[] allMats;
@@ -181,14 +189,19 @@ public class XGBoost {
       booster.setParams(params);
     }

-    //begin to train
-    for (int iter = booster.getVersion() / 2; iter < round; iter++) {
+    if (ecm != null) {
+      checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
+    }
+
+    // begin to train
+    for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
       if (booster.getVersion() % 2 == 0) {
         if (obj != null) {
           booster.update(dtrain, obj);
         } else {
           booster.update(dtrain, iter);
         }
+        saveCheckpoint(booster, iter, checkpointIterations, ecm);
         booster.saveRabitCheckpoint();
       }

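The training loop now resumes at booster.getVersion() / 2 and calls saveCheckpoint only when the current iteration is in the precomputed set, so checkpointing costs nothing when disabled (ecm stays null and the set stays empty). The control flow reduces to the following sketch (simplified; the commented update and save calls are stand-ins for the real library calls):

    // In-loop checkpoint hook (sketch of the control flow above).
    val checkpointIterations = Set(2, 4, 6, 7) // from getCheckpointRounds(2, 7)
    val startIter = 0                          // booster.getVersion / 2 when resuming
    for (iter <- startIter until 7) {
      // booster.update(dtrain, iter)          // one boosting round on every worker
      if (checkpointIterations.contains(iter)) {
        // ecm.updateCheckpoint(booster)       // persist without a new Spark job
      }
    }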
@@ -239,6 +252,44 @@ public class XGBoost {
     return booster;
   }

+  /**
+   * Train a booster given parameters.
+   *
+   * @param dtrain Data to be trained.
+   * @param params Parameters.
+   * @param round Number of boosting iterations.
+   * @param watches a group of items to be evaluated during training, this allows user to watch
+   *                performance on the validation set.
+   * @param metrics array containing the evaluation metrics for each matrix in watches for each
+   *                iteration
+   * @param earlyStoppingRounds if non-zero, training would be stopped
+   *                after a specified number of consecutive
+   *                goes to the unexpected direction in any evaluation metric.
+   * @param obj customized objective
+   * @param eval customized evaluation
+   * @param booster train from scratch if set to null; train from an existing booster if not null.
+   * @return The trained booster.
+   */
+  public static Booster train(
+      DMatrix dtrain,
+      Map<String, Object> params,
+      int round,
+      Map<String, DMatrix> watches,
+      float[][] metrics,
+      IObjective obj,
+      IEvaluation eval,
+      int earlyStoppingRounds,
+      Booster booster) throws XGBoostError {
+    try {
+      return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
+          earlyStoppingRounds, booster,
+          -1, null, null);
+    } catch (IOException e) {
+      logger.error("training failed in xgboost4j", e);
+      throw new XGBoostError("training failed in xgboost4j ", e);
+    }
+  }
+
   private static Integer tryGetIntFromObject(Object o) {
     if (o instanceof Integer) {
       return (int)o;
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
   public XGBoostError(String message) {
     super(message);
   }
+
+  public XGBoostError(String message, Throwable cause) {
+    super(message, cause);
+  }
 }
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
+import org.apache.hadoop.fs.FileSystem
+
+class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
+  extends JavaECM(checkpointPath, fs) {
+
+  def updateCheckpoint(booster: Booster): Unit = {
+    super.updateCheckpoint(booster.booster)
+  }
+
+  def loadCheckpointAsScalaBooster(): Booster = {
+    val loadedBooster = super.loadCheckpointAsBooster()
+    if (loadedBooster == null) {
+      null
+    } else {
+      new Booster(loadedBooster)
+    }
+  }
+}
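The Scala wrapper adds nothing beyond unwrapping and re-wrapping the Scala Booster; note that loadCheckpointAsScalaBooster deliberately mirrors the Java class and returns null rather than an Option, so existing null-checking call sites keep working. Callers that prefer idiomatic Scala can adapt at the boundary, as in this sketch (placeholder path and configuration):

    // Sketch: Option at the call-site boundary (placeholder path/conf).
    import ml.dmlc.xgboost4j.scala.{Booster, ExternalCheckpointManager}
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.FileSystem

    val manager = new ExternalCheckpointManager(
      "/tmp/xgb-checkpoints", FileSystem.get(new Configuration()))
    val resumed: Option[Booster] = Option(manager.loadCheckpointAsScalaBooster())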
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala

 import java.io.InputStream

-import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
+import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
 import scala.collection.JavaConverters._

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
 /**
  * XGBoost Scala Training function.
  */
 object XGBoost {

+  private[scala] def trainAndSaveCheckpoint(
+      dtrain: DMatrix,
+      params: Map[String, Any],
+      numRounds: Int,
+      watches: Map[String, DMatrix] = Map(),
+      metrics: Array[Array[Float]] = null,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      earlyStoppingRound: Int = 0,
+      prevBooster: Booster,
+      checkpointParams: Option[ExternalCheckpointParams]): Booster = {
+    val jWatches = watches.mapValues(_.jDMatrix).asJava
+    val jBooster = if (prevBooster == null) {
+      null
+    } else {
+      prevBooster.booster
+    }
+    val xgboostInJava = checkpointParams.
+      map(cp => {
+        JXGBoost.trainAndSaveCheckpoint(
+          dtrain.jDMatrix,
+          // we have to filter null value for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
+          cp.checkpointInterval,
+          cp.checkpointPath,
+          new Path(cp.checkpointPath).getFileSystem(new Configuration()))
+      }).
+      getOrElse(
+        JXGBoost.train(
+          dtrain.jDMatrix,
+          // we have to filter null value for customized obj and eval
+          params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+          numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
+      )
+    if (prevBooster == null) {
+      new Booster(xgboostInJava)
+    } else {
+      // Avoid creating a new SBooster with the same JBooster
+      prevBooster
+    }
+  }
+
 /**
  * Train a booster given parameters.
  *
@@ -55,23 +101,8 @@ object XGBoost {
       eval: EvalTrait = null,
       earlyStoppingRound: Int = 0,
       booster: Booster = null): Booster = {
-    val jWatches = watches.mapValues(_.jDMatrix).asJava
-    val jBooster = if (booster == null) {
-      null
-    } else {
-      booster.booster
-    }
-    val xgboostInJava = JXGBoost.train(
-      dtrain.jDMatrix,
-      // we have to filter null value for customized obj and eval
-      params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
-      round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
-    if (booster == null) {
-      new Booster(xgboostInJava)
-    } else {
-      // Avoid creating a new SBooster with the same JBooster
-      booster
-    }
+    trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
+      booster, None)
   }

 /**
@@ -126,3 +157,41 @@ object XGBoost {
     new Booster(xgboostInJava)
   }
 }
+
+private[scala] case class ExternalCheckpointParams(
+    checkpointInterval: Int,
+    checkpointPath: String,
+    skipCleanCheckpoint: Boolean)
+
+private[scala] object ExternalCheckpointParams {
+
+  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
+    val checkpointPath: String = params.get("checkpoint_path") match {
+      case None | Some(null) | Some("") => null
+      case Some(path: String) => path
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
+        s" an instance of String, but current value is ${params("checkpoint_path")}")
+    }
+
+    val checkpointInterval: Int = params.get("checkpoint_interval") match {
+      case None => 0
+      case Some(freq: Int) => freq
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
+        " an instance of Int.")
+    }
+
+    val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean")
+    }
+    if (checkpointPath == null || checkpointInterval == 0) {
+      None
+    } else {
+      Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
+    }
+  }
+}
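extractParams now folds the three user-facing parameters into a single Option: checkpointing is active only when a non-empty path and a non-zero interval are both present, and an incomplete configuration simply disables it instead of throwing (wrongly typed values still throw). Expected behavior, sketched with hypothetical parameter maps (the object is private[scala], so this is illustrative only):

    // Sketch: how extractParams resolves parameter maps (hypothetical values).
    ExternalCheckpointParams.extractParams(
      Map("checkpoint_path" -> "hdfs:///tmp/ckpt", "checkpoint_interval" -> 2))
    // => Some(ExternalCheckpointParams(2, "hdfs:///tmp/ckpt", false))

    ExternalCheckpointParams.extractParams(Map("checkpoint_interval" -> 2))
    // => None: no path configured, so external checkpointing stays off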