[jvm-packages] Refactor XGBoost.scala to put all params processing in one place (#4815)
* clean checkpoint file after a successful training
* address comments
* refactor xgboost.scala to avoid multiple changes when adding params
* consolidate params
* fix compilation issue
* fix failed test
* fix wrong name
* type conversion
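For orientation, the net effect of the refactor is that all user-supplied parameters now flow through a single factory that validates them once and produces one typed bundle. A minimal sketch of the call pattern the diff introduces; `rawParams` and `sc` are stand-ins for any raw parameter map and SparkContext, and the factory is package-private, so this mirrors internal usage rather than a public API:

  // Sketch of the consolidated flow (internal usage, not a public API):
  val rawParams: Map[String, Any] =
    Map("num_workers" -> 2, "num_round" -> 10, "use_external_memory" -> false)
  val xgbExecParams =
    new XGBoostExecutionParamsFactory(rawParams, sc).buildXGBRuntimeParams
  // Downstream code reads typed fields instead of re-parsing the map:
  val nWorkers = xgbExecParams.numWorkers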
This commit is contained in:
parent 830e73901d
commit 0184eb5d02
@@ -24,6 +24,8 @@ import scala.util.Random

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -55,6 +57,172 @@ object TrackerConf {
  def apply(): TrackerConf = TrackerConf(0L, "python")
}

private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
    maximizeEvalMetrics: Boolean)

private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)

private[this] case class XGBoostExecutionParams(
    numWorkers: Int,
    round: Int,
    useExternalMemory: Boolean,
    obj: ObjectiveTrait,
    eval: EvalTrait,
    missing: Float,
    trackerConf: TrackerConf,
    timeoutRequestWorkers: Long,
    checkpointParam: CheckpointParam,
    xgbInputParams: XGBoostExecutionInputParams,
    earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
    cacheTrainingSet: Boolean) {

  private var rawParamMap: Map[String, Any] = _

  def setRawParamMap(inputMap: Map[String, Any]): Unit = {
    rawParamMap = inputMap
  }

  def toMap: Map[String, Any] = {
    rawParamMap
  }
}
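The case class above carries typed fields for Spark-side orchestration, while the raw flat map is stashed separately so it can be handed to the native layer unchanged. A generic, self-contained sketch of that pattern (the names `ExecParams`, `setRaw`, and `raw` are illustrative, not the library's):

  // Immutable case class for typed access, plus a late-bound raw map
  // that passes through to a lower layer untouched.
  case class ExecParams(numWorkers: Int, round: Int) {
    private var raw: Map[String, Any] = Map.empty
    def setRaw(m: Map[String, Any]): Unit = { raw = m }
    def toMap: Map[String, Any] = raw
  }
  val ep = ExecParams(numWorkers = 2, round = 10)
  ep.setRaw(Map("nthread" -> 1, "num_round" -> 10))
  assert(ep.toMap("nthread") == 1)  // the flat map survives unchanged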

private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext) {

  private val logger = LogFactory.getLog("XGBoostSpark")

  private val overridedParams = overrideParams(rawParams, sc)

  /**
   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
   * If so, throw an exception unless this safety measure has been explicitly overridden
   * via conf `xgboost.spark.ignoreSsl`.
   */
  private def validateSparkSslConf: Unit = {
    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
      SparkSession.getActiveSession match {
        case Some(ss) =>
          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
        case None =>
          (sc.getConf.getBoolean("spark.ssl.enabled", false),
            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
      }
    if (sparkSslEnabled) {
      if (xgboostSparkIgnoreSsl) {
        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
      } else {
        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
          "To override this protection and still use xgboost-spark at your own risk, " +
          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
      }
    }
  }
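A short sketch of opting out of that SSL safety check, at your own risk, using the two conf keys shown above. It assumes an active SparkSession; the builder settings are placeholders:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().appName("xgb").getOrCreate()
  // With spark.ssl.enabled=true, training would throw; this downgrades it to a warning.
  spark.conf.set("xgboost.spark.ignoreSsl", "true")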

  /**
   * We should not include any nested structure in the output of this function, as the map is
   * eventually fed to the xgboost4j layer.
   */
  private def overrideParams(
      params: Map[String, Any],
      sc: SparkContext): Map[String, Any] = {
    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
    var overridedParams = params
    if (overridedParams.contains("nthread")) {
      val nThread = overridedParams("nthread").toString.toInt
      require(nThread <= coresPerTask,
        s"the nthread configuration ($nThread) must be no larger than " +
        s"spark.task.cpus ($coresPerTask)")
    } else {
      overridedParams = overridedParams + ("nthread" -> coresPerTask)
    }

    val numEarlyStoppingRounds = overridedParams.getOrElse(
      "num_early_stopping_rounds", 0).asInstanceOf[Int]
    overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
    if (numEarlyStoppingRounds > 0 &&
      !overridedParams.contains("maximize_evaluation_metrics")) {
      if (overridedParams.contains("custom_eval")) {
        throw new IllegalArgumentException("custom_eval does not support early stopping")
      }
      val eval_metric = overridedParams("eval_metric").toString
      val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
      logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
      overridedParams += ("maximize_evaluation_metrics" -> maximize)
    }
    overridedParams
  }
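The nthread rule above is worth isolating: the per-worker thread count must not exceed the CPUs Spark allocates to each task, and it defaults to that allocation when unset. A self-contained sketch of the same logic (the helper name is illustrative):

  // Mirrors the nthread handling in overrideParams above.
  def resolveNthread(params: Map[String, Any], coresPerTask: Int): Map[String, Any] =
    params.get("nthread") match {
      case Some(n) =>
        require(n.toString.toInt <= coresPerTask,
          s"nthread ($n) must be no larger than spark.task.cpus ($coresPerTask)")
        params
      case None => params + ("nthread" -> coresPerTask)  // default to the task's CPUs
    }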

  def buildXGBRuntimeParams: XGBoostExecutionParams = {
    val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
    val round = overridedParams("num_round").asInstanceOf[Int]
    val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
    val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
    val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
    val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
    validateSparkSslConf

    if (overridedParams.contains("tree_method")) {
      require(overridedParams("tree_method") == "hist" ||
        overridedParams("tree_method") == "approx" ||
        overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
        " 'hist', 'approx' and 'auto'")
    }
    if (overridedParams.contains("train_test_ratio")) {
      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
        "'eval_set_names'")
    }
    require(nWorkers > 0, "you must specify more than 0 workers")
    if (obj != null) {
      require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
        "is not defined, you have to specify the objective type as classification or regression" +
        " with a customized objective function")
    }
    val trackerConf = overridedParams.get("tracker_conf") match {
      case None => TrackerConf()
      case Some(conf: TrackerConf) => conf
      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
        "instance of TrackerConf.")
    }
    val timeoutRequestWorkers: Long = overridedParams.get("timeout_request_workers") match {
      case None => 0L
      case Some(interval: Long) => interval
      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
        " an instance of Long.")
    }
    val checkpointParam =
      CheckpointManager.extractParams(overridedParams)

    val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
      .asInstanceOf[Double]
    val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
    val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)

    val earlyStoppingRounds = overridedParams.getOrElse(
      "num_early_stopping_rounds", 0).asInstanceOf[Int]
    val maximizeEvalMetrics = overridedParams.getOrElse(
      "maximize_evaluation_metrics", true).asInstanceOf[Boolean]
    val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
      maximizeEvalMetrics)

    val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
      .asInstanceOf[Boolean]

    val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
      missing, trackerConf,
      timeoutRequestWorkers,
      checkpointParam,
      inputParams,
      xgbExecEarlyStoppingParams,
      cacheTrainingSet)
    xgbExecParam.setRawParamMap(overridedParams)
    xgbExecParam
  }
}
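buildXGBRuntimeParams repeatedly narrows loosely typed map entries with a pattern match, rejecting wrong types with a clear error instead of a ClassCastException deep in training. A self-contained sketch of that extraction style, modeled on the timeout_request_workers case above (the function name is illustrative):

  // Narrow an Any-typed map entry to its expected type, or fail loudly.
  def extractTimeout(params: Map[String, Any]): Long =
    params.get("timeout_request_workers") match {
      case None => 0L                        // absent: use the default
      case Some(interval: Long) => interval  // present with the right type
      case _ => throw new IllegalArgumentException(
        "parameter \"timeout_request_workers\" must be an instance of Long.")
    }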

/**
 * Training data group in an RDD partition.
 * @param groupId The group id
@@ -136,7 +304,7 @@ object XGBoost extends Serializable {

  private def buildDistributedBooster(
      watches: Watches,
      params: Map[String, Any],
      xgbExecutionParam: XGBoostExecutionParams,
      rabitEnv: java.util.Map[String, String],
      round: Int,
      obj: ObjectiveTrait,
@@ -157,24 +325,9 @@ object XGBoost extends Serializable {

    try {
      Rabit.init(rabitEnv)

      val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
        .map(_.toString.toInt).getOrElse(0)
      val overridedParams = if (numEarlyStoppingRounds > 0 &&
        !params.contains("maximize_evaluation_metrics")) {
        if (params.contains("custom_eval")) {
          throw new IllegalArgumentException("maximize_evaluation_metrics has to be "
            + "specified when custom_eval is set")
        }
        val eval_metric = params("eval_metric").toString
        val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
        logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
        params + ("maximize_evaluation_metrics" -> maximize)
      } else {
        params
      }
      val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
      val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round,
      val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
        watches.toMap, metrics, obj, eval,
        earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
      Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
@@ -188,22 +341,6 @@ object XGBoost extends Serializable {
    }
  }
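The metrics buffer passed to SXGBoost.train above has a fixed shape: one Float array of length `round` per watch. A self-contained sketch of that allocation and how the results end up keyed by watch name, exactly as the Iterator at the end of the hunk does:

  // One eval-history row per watch (e.g. "train", "test"), `round` entries each.
  val watchNames = Seq("train", "test")
  val round = 10
  val metrics = Array.tabulate(watchNames.size)(_ => Array.ofDim[Float](round))
  // After training, metrics(i)(r) holds watch i's eval value at round r;
  // pairing names with rows yields the Map returned per partition:
  val byName: Map[String, Array[Float]] = watchNames.zip(metrics).toMap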

  private def overrideParamsAccordingToTaskCPUs(
      params: Map[String, Any],
      sc: SparkContext): Map[String, Any] = {
    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
    var overridedParams = params
    if (overridedParams.contains("nthread")) {
      val nThread = overridedParams("nthread").toString.toInt
      require(nThread <= coresPerTask,
        s"the nthread configuration ($nThread) must be no larger than " +
        s"spark.task.cpus ($coresPerTask)")
    } else {
      overridedParams = params + ("nthread" -> coresPerTask)
    }
    overridedParams
  }

  private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
    val tracker: IRabitTracker = trackerConf.trackerImpl match {
      case "scala" => new RabitTracker(nWorkers)
@@ -261,138 +398,64 @@ object XGBoost extends Serializable {
    }
  }
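startTracker selects the Rabit tracker implementation from TrackerConf.trackerImpl, with the Scala/JVM RabitTracker chosen for "scala" and the Python tracker (imported above as PyRabitTracker) otherwise; only the "scala" case is visible in this hunk, so the fallback is inferred from the imports. A sketch of the configuration side, using the companion default shown earlier; the meaning of the first field is not shown in this diff, so treat the comment as an assumption:

  // Companion default is TrackerConf(0L, "python"); select the JVM tracker instead.
  val conf = TrackerConf(0L, "scala")  // (first field: a timeout-like Long, per the 0L default)
  // startTracker(nWorkers, conf) then matches conf.trackerImpl:
  //   case "scala" => new RabitTracker(nWorkers)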

  /**
   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
   * If so, throw an exception unless this safety measure has been explicitly overridden
   * via conf `xgboost.spark.ignoreSsl`.
   *
   * @param sc SparkContext for the training dataset. When looking for the confs, this method
   *           first checks for an active SparkSession. If one is not available, it falls back
   *           to this SparkContext.
   */
  private def validateSparkSslConf(sc: SparkContext): Unit = {
    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
      SparkSession.getActiveSession match {
        case Some(ss) =>
          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
        case None =>
          (sc.getConf.getBoolean("spark.ssl.enabled", false),
            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
      }
    if (sparkSslEnabled) {
      if (xgboostSparkIgnoreSsl) {
        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
      } else {
        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
          "To override this protection and still use xgboost-spark at your own risk, " +
          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
      }
    }
  }

  private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = {
    val nWorkers = params("num_workers").asInstanceOf[Int]
    val round = params("num_round").asInstanceOf[Int]
    val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean]
    val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
    val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
    val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
    validateSparkSslConf(sparkContext)

    if (params.contains("tree_method")) {
      require(params("tree_method") == "hist" ||
        params("tree_method") == "approx" ||
        params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as 'hist'," +
        " 'approx' and 'auto'")
    }
    if (params.contains("train_test_ratio")) {
      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
        "'eval_set_names'")
    }
    require(nWorkers > 0, "you must specify more than 0 workers")
    if (obj != null) {
      require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
        " defined, you have to specify the objective type as classification or regression" +
        " with a customized objective function")
    }
    val trackerConf = params.get("tracker_conf") match {
      case None => TrackerConf()
      case Some(conf: TrackerConf) => conf
      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
        "instance of TrackerConf.")
    }
    val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
      case None => 0L
      case Some(interval: Long) => interval
      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
        " an instance of Long.")
    }
    val checkpointParam =
      CheckpointManager.extractParams(params)
    (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
      checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
      checkpointParam.skipCleanCheckpoint)
  }
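parameterFetchAndValidation is the helper the new XGBoostExecutionParams case class supersedes: it returns an 11-element tuple, so every caller must destructure positionally with underscore placeholders, and adding a parameter shifts every position. A small self-contained contrast of the two styles (names are illustrative):

  // Tuple style: call sites count underscores, and positions are brittle.
  def fetchTuple(p: Map[String, Any]): (Int, Int) =
    (p("num_workers").asInstanceOf[Int], p("num_round").asInstanceOf[Int])
  val (nWorkers, _) = fetchTuple(Map("num_workers" -> 2, "num_round" -> 10))

  // Case-class style: named fields, so adding a parameter shifts nothing.
  case class Exec(numWorkers: Int, round: Int)
  val exec = Exec(numWorkers = 2, round = 10)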

  private def trainForNonRanking(
      trainingData: RDD[XGBLabeledPoint],
      params: Map[String, Any],
      xgbExecutionParams: XGBoostExecutionParams,
      rabitEnv: java.util.Map[String, String],
      checkpointRound: Int,
      prevBooster: Booster,
      evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
      parameterFetchAndValidation(params, trainingData.sparkContext)
    if (evalSetsMap.isEmpty) {
      trainingData.mapPartitions(labeledPoints => {
        val watches = Watches.buildWatches(params,
          processMissingValues(labeledPoints, missing),
          getCacheDirName(useExternalMemory))
        buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
          obj, eval, prevBooster)
        val watches = Watches.buildWatches(xgbExecutionParams,
          processMissingValues(labeledPoints, xgbExecutionParams.missing),
          getCacheDirName(xgbExecutionParams.useExternalMemory))
        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
      }).cache()
    } else {
      coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
        nameAndLabeledPointSets =>
          val watches = Watches.buildWatches(
            nameAndLabeledPointSets.map {
              case (name, iter) => (name, processMissingValues(iter, missing))},
            getCacheDirName(useExternalMemory))
          buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
            obj, eval, prevBooster)
      }.cache()
      coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
        mapPartitions {
          nameAndLabeledPointSets =>
            val watches = Watches.buildWatches(
              nameAndLabeledPointSets.map {
                case (name, iter) => (name, processMissingValues(iter,
                  xgbExecutionParams.missing))
              },
              getCacheDirName(xgbExecutionParams.useExternalMemory))
            buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
              xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
        }.cache()
    }
  }
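Both branches above follow the same distributed-training shape: each Spark partition builds its own watches and trains one booster worker, and the resulting RDD is cached so checkpoint iterations can reuse it. A self-contained toy sketch of that mapPartitions pattern, with trivial stand-ins for watches and booster (assumes a SparkContext `sc`):

  // Toy stand-in for the per-partition training pattern; not the real training code.
  val data = sc.parallelize(1 to 100, numSlices = 4)
  val perPartitionResults = data.mapPartitions { it =>
    val partitionSum = it.sum  // stand-in for "build watches, train one worker"
    Iterator(partitionSum -> Map("train" -> Array.fill(10)(0f)))
  }.cache()  // cached, like boostersAndMetrics in the real code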

  private def trainForRanking(
      trainingData: RDD[Array[XGBLabeledPoint]],
      params: Map[String, Any],
      xgbExecutionParam: XGBoostExecutionParams,
      rabitEnv: java.util.Map[String, String],
      checkpointRound: Int,
      prevBooster: Booster,
      evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
      parameterFetchAndValidation(params, trainingData.sparkContext)
    if (evalSetsMap.isEmpty) {
      trainingData.mapPartitions(labeledPointGroups => {
        val watches = Watches.buildWatchesWithGroup(params,
          processMissingValuesWithGroup(labeledPointGroups, missing),
          getCacheDirName(useExternalMemory))
        buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
        val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing),
          getCacheDirName(xgbExecutionParam.useExternalMemory))
        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
          xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
      }).cache()
    } else {
      coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
      coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
        labeledPointGroupSets => {
          val watches = Watches.buildWatchesWithGroup(
            labeledPointGroupSets.map {
              case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
              case (name, iter) => (name, processMissingValuesWithGroup(iter,
                xgbExecutionParam.missing))
            },
            getCacheDirName(useExternalMemory))
          buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
            getCacheDirName(xgbExecutionParam.useExternalMemory))
          buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
            xgbExecutionParam.obj,
            xgbExecutionParam.eval,
            prevBooster)
        }).cache()
    }
  }
@@ -428,31 +491,32 @@ object XGBoost extends Serializable {
      evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
    (Booster, Map[String, Array[Float]]) = {
    logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
    val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
      checkpointPath, checkpointInterval, skipCleanCheckpoint) =
      parameterFetchAndValidation(params,
        trainingData.sparkContext)
    val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
      buildXGBRuntimeParams
    val sc = trainingData.sparkContext
    val checkpointManager = new CheckpointManager(sc, checkpointPath)
    checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
    val transformedTrainingData = composeInputData(trainingData,
      params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
      checkpointPath)
    checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
    val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
      hasGroup, xgbExecParams.numWorkers)
    var prevBooster = checkpointManager.loadCheckpointAsBooster
    try {
      // Train for every ${savingRound} rounds and save the partially completed booster
      val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
      val producedBooster = checkpointManager.getCheckpointRounds(
        xgbExecParams.checkpointParam.checkpointInterval,
        xgbExecParams.round).map {
        checkpointRound: Int =>
          val tracker = startTracker(nWorkers, trackerConf)
          val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
          try {
            val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
            val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
              nWorkers)
            val parallelismTracker = new SparkParallelismTracker(sc,
              xgbExecParams.timeoutRequestWorkers,
              xgbExecParams.numWorkers)
            val rabitEnv = tracker.getWorkerEnvs
            val boostersAndMetrics = if (hasGroup) {
              trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
              trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
                checkpointRound, prevBooster, evalSetsMap)
            } else {
              trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
              trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
                checkpointRound, prevBooster, evalSetsMap)
            }
            val sparkJobThread = new Thread() {
@@ -467,7 +531,7 @@ object XGBoost extends Serializable {
            logger.info(s"Rabit returns with exit code $trackerReturnVal")
            val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
              boostersAndMetrics, sparkJobThread)
            if (checkpointRound < round) {
            if (checkpointRound < xgbExecParams.round) {
              prevBooster = booster
              checkpointManager.updateCheckpoint(prevBooster)
            }
@@ -477,7 +541,7 @@ object XGBoost extends Serializable {
        }
      }.last
      // we should delete the checkpoint directory after a successful training
      if (!skipCleanCheckpoint) {
      if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
        checkpointManager.cleanPath()
      }
      producedBooster
@@ -488,8 +552,7 @@ object XGBoost extends Serializable {
      trainingData.sparkContext.stop()
      throw t
    } finally {
      uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
        transformedTrainingData)
      uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
    }
  }
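The training loop above iterates over getCheckpointRounds(checkpointInterval, round), training up to each boundary, saving the partial booster when checkpointRound < round, and taking .last as the final model. The manager's implementation isn't shown in this diff, so the stand-in below is hypothetical; it only illustrates the shape of the sequence the loop consumes:

  // Hypothetical stand-in for CheckpointManager.getCheckpointRounds:
  // checkpoint boundaries at every `interval` rounds, always ending at `round`.
  // (The real implementation also accounts for a previously saved checkpoint.)
  def checkpointRounds(interval: Int, round: Int): Seq[Int] =
    if (interval <= 0) Seq(round)
    else (interval to round by interval).toSeq match {
      case rs if rs.lastOption.contains(round) => rs
      case rs => rs :+ round
    }
  // e.g. checkpointRounds(4, 10) == Seq(4, 8, 10)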
@@ -673,11 +736,11 @@ private object Watches {
  }

  def buildWatches(
      params: Map[String, Any],
      xgbExecutionParams: XGBoostExecutionParams,
      labeledPoints: Iterator[XGBLabeledPoint],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
    val seed = xgbExecutionParams.xgbInputParams.seed
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
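buildWatches now reads train_test_ratio and seed from the typed bundle instead of re-parsing the map, then seeds a Random to carve the partition's points into train and test sets. The exact split predicate is outside this hunk, so the sketch below only illustrates the seeded-ratio idea with names of its own:

  import scala.util.Random

  // Illustrative seeded split driven by the two values read above.
  def split[T](points: Seq[T], trainTestRatio: Double, seed: Long): (Seq[T], Seq[T]) = {
    val r = new Random(seed)  // same seed => same split on every worker
    points.partition(_ => r.nextDouble() <= trainTestRatio)
  }
  val (train, test) = split(1 to 100, trainTestRatio = 0.8, seed = 42L)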
@@ -743,11 +806,11 @@ private object Watches {
  }

  def buildWatchesWithGroup(
      params: Map[String, Any],
      xgbExecutionParams: XGBoostExecutionParams,
      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
    val seed = xgbExecutionParams.xgbInputParams.seed
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat