[jvm-packages] Refactor XGBoost.scala to put all params processing in one place (#4815)
* cleaning checkpoint file after a successful training
* address comments
* refactor xgboost.scala to avoid multiple changes when adding params
* consolidate params
* fix compilation issue
* fix failed test
* fix wrong name
* type conversion
parent 830e73901d
commit 0184eb5d02
@@ -24,6 +24,8 @@ import scala.util.Random
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
+import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
+import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -55,6 +57,172 @@ object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
 }
+
+private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
+    maximizeEvalMetrics: Boolean)
+
+private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
+
+private[this] case class XGBoostExecutionParams(
+    numWorkers: Int,
+    round: Int,
+    useExternalMemory: Boolean,
+    obj: ObjectiveTrait,
+    eval: EvalTrait,
+    missing: Float,
+    trackerConf: TrackerConf,
+    timeoutRequestWorkers: Long,
+    checkpointParam: CheckpointParam,
+    xgbInputParams: XGBoostExecutionInputParams,
+    earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
+    cacheTrainingSet: Boolean) {
+
+  private var rawParamMap: Map[String, Any] = _
+
+  def setRawParamMap(inputMap: Map[String, Any]): Unit = {
+    rawParamMap = inputMap
+  }
+
+  def toMap: Map[String, Any] = {
+    rawParamMap
+  }
+}
+
+private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){
+
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
+  private val overridedParams = overrideParams(rawParams, sc)
+
+  /**
+   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
+   * If so, throw an exception unless this safety measure has been explicitly overridden
+   * via conf `xgboost.spark.ignoreSsl`.
+   */
+  private def validateSparkSslConf: Unit = {
+    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
+      SparkSession.getActiveSession match {
+        case Some(ss) =>
+          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
+            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
+        case None =>
+          (sc.getConf.getBoolean("spark.ssl.enabled", false),
+            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
+      }
+    if (sparkSslEnabled) {
+      if (xgboostSparkIgnoreSsl) {
+        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
+          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
+      } else {
+        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
+          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
+          "To override this protection and still use xgboost-spark at your own risk, " +
+          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
+      }
+    }
+  }
+
+  /**
+   * we should not include any nested structure in the output of this function as the map is
+   * eventually to be feed to xgboost4j layer
+   */
+  private def overrideParams(
+      params: Map[String, Any],
+      sc: SparkContext): Map[String, Any] = {
+    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
+    var overridedParams = params
+    if (overridedParams.contains("nthread")) {
+      val nThread = overridedParams("nthread").toString.toInt
+      require(nThread <= coresPerTask,
+        s"the nthread configuration ($nThread) must be no larger than " +
+          s"spark.task.cpus ($coresPerTask)")
+    } else {
+      overridedParams = overridedParams + ("nthread" -> coresPerTask)
+    }
+
+    val numEarlyStoppingRounds = overridedParams.getOrElse(
+      "num_early_stopping_rounds", 0).asInstanceOf[Int]
+    overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
+    if (numEarlyStoppingRounds > 0 &&
+      !overridedParams.contains("maximize_evaluation_metrics")) {
+      if (overridedParams.contains("custom_eval")) {
+        throw new IllegalArgumentException("custom_eval does not support early stopping")
+      }
+      val eval_metric = overridedParams("eval_metric").toString
+      val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
+      logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
+      overridedParams += ("maximize_evaluation_metrics" -> maximize)
+    }
+    overridedParams
+  }
+
+  def buildXGBRuntimeParams: XGBoostExecutionParams = {
+    val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
+    val round = overridedParams("num_round").asInstanceOf[Int]
+    val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
+    val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
+    val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
+    val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
+    validateSparkSslConf
+
+    if (overridedParams.contains("tree_method")) {
+      require(overridedParams("tree_method") == "hist" ||
+        overridedParams("tree_method") == "approx" ||
+        overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
+        " 'hist', 'approx' and 'auto'")
+    }
+    if (overridedParams.contains("train_test_ratio")) {
+      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
+        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
+        "'eval_set_names'")
+    }
+    require(nWorkers > 0, "you must specify more than 0 workers")
+    if (obj != null) {
+      require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
+        "is not defined, you have to specify the objective type as classification or regression" +
+        " with a customized objective function")
+    }
+    val trackerConf = overridedParams.get("tracker_conf") match {
+      case None => TrackerConf()
+      case Some(conf: TrackerConf) => conf
+      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
+        "instance of TrackerConf.")
+    }
+    val timeoutRequestWorkers: Long = overridedParams.get("timeout_request_workers") match {
+      case None => 0L
+      case Some(interval: Long) => interval
+      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
+        " an instance of Long.")
+    }
+    val checkpointParam =
+      CheckpointManager.extractParams(overridedParams)
+
+    val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
+      .asInstanceOf[Double]
+    val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
+    val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
+
+    val earlyStoppingRounds = overridedParams.getOrElse(
+      "num_early_stopping_rounds", 0).asInstanceOf[Int]
+    val maximizeEvalMetrics = overridedParams.getOrElse(
+      "maximize_evaluation_metrics", true).asInstanceOf[Boolean]
+    val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
+      maximizeEvalMetrics)
+
+    val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
+      .asInstanceOf[Boolean]
+
+    val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
+      missing, trackerConf,
+      timeoutRequestWorkers,
+      checkpointParam,
+      inputParams,
+      xgbExecEarlyStoppingParams,
+      cacheTrainingSet)
+    xgbExecParam.setRawParamMap(overridedParams)
+    xgbExecParam
+  }
+}
+
 /**
  * Traing data group in a RDD partition.
  * @param groupId The group id
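The new factory gives the refactor a single choke point: raw user params go in once, are validated and overridden once, and everything downstream reads typed fields. A minimal sketch of that flow (not part of the diff), assuming an existing SparkContext `sc` and glossing over the fact that the factory is file-private in the real code:

// Sketch only: raw params -> overrideParams -> buildXGBRuntimeParams -> typed fields.
val rawParams: Map[String, Any] = Map(
  "num_workers" -> 2,             // required by buildXGBRuntimeParams
  "num_round" -> 100,             // required
  "use_external_memory" -> false, // required
  "eval_metric" -> "auc",
  "num_early_stopping_rounds" -> 10)
val xgbExecParams = new XGBoostExecutionParamsFactory(rawParams, sc).buildXGBRuntimeParams
assert(xgbExecParams.numWorkers == 2)
assert(xgbExecParams.earlyStoppingParams.numEarlyStoppingRounds == 10)
// toMap returns the flat, already-overridden map that is fed to the xgboost4j layer:
assert(xgbExecParams.toMap("nthread") == sc.getConf.getInt("spark.task.cpus", 1))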
@@ -136,7 +304,7 @@ object XGBoost extends Serializable {
 
   private def buildDistributedBooster(
       watches: Watches,
-      params: Map[String, Any],
+      xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
       round: Int,
       obj: ObjectiveTrait,
@@ -157,24 +325,9 @@ object XGBoost extends Serializable {
 
     try {
       Rabit.init(rabitEnv)
-      val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
-        .map(_.toString.toInt).getOrElse(0)
-      val overridedParams = if (numEarlyStoppingRounds > 0 &&
-        !params.contains("maximize_evaluation_metrics")) {
-        if (params.contains("custom_eval")) {
-          throw new IllegalArgumentException("maximize_evaluation_metrics has to be "
-            + "specified when custom_eval is set")
-        }
-        val eval_metric = params("eval_metric").toString
-        val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
-        logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
-        params + ("maximize_evaluation_metrics" -> maximize)
-      } else {
-        params
-      }
+      val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
       val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
-      val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round,
+      val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
         watches.toMap, metrics, obj, eval,
         earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
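The block deleted above was the last place a worker re-derived `maximize_evaluation_metrics` on its own; that rule now runs exactly once on the driver, inside `overrideParams`. A self-contained sketch of the rule, with an assumed stand-in for `LearningTaskParams.evalMetricsToMaximize` (the real set lives in that object):

object MaximizeDerivationSketch extends App {
  // Assumed contents for illustration; not the real LearningTaskParams set.
  val evalMetricsToMaximize = Set("auc", "aucpr", "ndcg", "map")

  // Mirrors overrideParams: derive the flag only when early stopping is
  // enabled and the user did not set it explicitly.
  def deriveMaximize(params: Map[String, Any]): Map[String, Any] = {
    val rounds = params.getOrElse("num_early_stopping_rounds", 0).asInstanceOf[Int]
    if (rounds > 0 && !params.contains("maximize_evaluation_metrics")) {
      val maximize = evalMetricsToMaximize contains params("eval_metric").toString
      params + ("maximize_evaluation_metrics" -> maximize)
    } else params
  }

  println(deriveMaximize(Map("num_early_stopping_rounds" -> 5, "eval_metric" -> "auc")))
  // Map(num_early_stopping_rounds -> 5, eval_metric -> auc, maximize_evaluation_metrics -> true)
}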
@@ -188,22 +341,6 @@ object XGBoost extends Serializable {
     }
   }
 
-  private def overrideParamsAccordingToTaskCPUs(
-      params: Map[String, Any],
-      sc: SparkContext): Map[String, Any] = {
-    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
-    var overridedParams = params
-    if (overridedParams.contains("nthread")) {
-      val nThread = overridedParams("nthread").toString.toInt
-      require(nThread <= coresPerTask,
-        s"the nthread configuration ($nThread) must be no larger than " +
-        s"spark.task.cpus ($coresPerTask)")
-    } else {
-      overridedParams = params + ("nthread" -> coresPerTask)
-    }
-    overridedParams
-  }
-
   private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
     val tracker: IRabitTracker = trackerConf.trackerImpl match {
       case "scala" => new RabitTracker(nWorkers)
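This deletion removes the duplicate of the nthread rule; the surviving copy is `overrideParams` in the factory above. Distilled, the rule is: an explicit `nthread` must fit within `spark.task.cpus`, and when absent it defaults to that value. A standalone sketch of the same logic:

// Sketch of the single surviving nthread rule (see overrideParams above).
def reconcileNthread(params: Map[String, Any], coresPerTask: Int): Map[String, Any] =
  params.get("nthread") match {
    case Some(n) =>
      require(n.toString.toInt <= coresPerTask,
        s"the nthread configuration ($n) must be no larger than spark.task.cpus ($coresPerTask)")
      params
    case None =>
      params + ("nthread" -> coresPerTask) // default to the task's CPU budget
  }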
@@ -261,138 +398,64 @@ object XGBoost extends Serializable {
     }
   }
 
-  /**
-   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
-   * If so, throw an exception unless this safety measure has been explicitly overridden
-   * via conf `xgboost.spark.ignoreSsl`.
-   *
-   * @param sc SparkContext for the training dataset. When looking for the confs, this method
-   *           first checks for an active SparkSession. If one is not available, it falls back
-   *           to this SparkContext.
-   */
-  private def validateSparkSslConf(sc: SparkContext): Unit = {
-    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
-      SparkSession.getActiveSession match {
-        case Some(ss) =>
-          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
-            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
-        case None =>
-          (sc.getConf.getBoolean("spark.ssl.enabled", false),
-            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
-      }
-    if (sparkSslEnabled) {
-      if (xgboostSparkIgnoreSsl) {
-        logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
-          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
-      } else {
-        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
-          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
-          "To override this protection and still use xgboost-spark at your own risk, " +
-          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
-      }
-    }
-  }
-
-  private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = {
-    val nWorkers = params("num_workers").asInstanceOf[Int]
-    val round = params("num_round").asInstanceOf[Int]
-    val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean]
-    val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
-    val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
-    val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
-    validateSparkSslConf(sparkContext)
-
-    if (params.contains("tree_method")) {
-      require(params("tree_method") == "hist" ||
-        params("tree_method") == "approx" ||
-        params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as 'hist'," +
-        " 'approx' and 'auto'")
-    }
-    if (params.contains("train_test_ratio")) {
-      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
-        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
-        "'eval_set_names'")
-    }
-    require(nWorkers > 0, "you must specify more than 0 workers")
-    if (obj != null) {
-      require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" +
-        " defined, you have to specify the objective type as classification or regression" +
-        " with a customized objective function")
-    }
-    val trackerConf = params.get("tracker_conf") match {
-      case None => TrackerConf()
-      case Some(conf: TrackerConf) => conf
-      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
-        "instance of TrackerConf.")
-    }
-    val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match {
-      case None => 0L
-      case Some(interval: Long) => interval
-      case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
-        " an instance of Long.")
-    }
-    val checkpointParam =
-      CheckpointManager.extractParams(params)
-    (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
-      checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
-      checkpointParam.skipCleanCheckpoint)
-  }
-
   private def trainForNonRanking(
       trainingData: RDD[XGBLabeledPoint],
-      params: Map[String, Any],
+      xgbExecutionParams: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
-      parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
-        val watches = Watches.buildWatches(params,
-          processMissingValues(labeledPoints, missing),
-          getCacheDirName(useExternalMemory))
-        buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
-          obj, eval, prevBooster)
+        val watches = Watches.buildWatches(xgbExecutionParams,
+          processMissingValues(labeledPoints, xgbExecutionParams.missing),
+          getCacheDirName(xgbExecutionParams.useExternalMemory))
+        buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
+          xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
       }).cache()
     } else {
-      coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions {
+      coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
+        mapPartitions {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
-              case (name, iter) => (name, processMissingValues(iter, missing))},
-            getCacheDirName(useExternalMemory))
-          buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
-            obj, eval, prevBooster)
+              case (name, iter) => (name, processMissingValues(iter,
+                xgbExecutionParams.missing))
+            },
+            getCacheDirName(xgbExecutionParams.useExternalMemory))
+          buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
+            xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
       }.cache()
     }
   }
 
   private def trainForRanking(
       trainingData: RDD[Array[XGBLabeledPoint]],
-      params: Map[String, Any],
+      xgbExecutionParam: XGBoostExecutionParams,
       rabitEnv: java.util.Map[String, String],
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
-      parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
-        val watches = Watches.buildWatchesWithGroup(params,
-          processMissingValuesWithGroup(labeledPointGroups, missing),
-          getCacheDirName(useExternalMemory))
-        buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
+        val watches = Watches.buildWatchesWithGroup(xgbExecutionParam,
+          processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing),
+          getCacheDirName(xgbExecutionParam.useExternalMemory))
+        buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+          xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
       }).cache()
     } else {
-      coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions(
+      coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
        labeledPointGroupSets => {
          val watches = Watches.buildWatchesWithGroup(
            labeledPointGroupSets.map {
-              case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
+              case (name, iter) => (name, processMissingValuesWithGroup(iter,
+                xgbExecutionParam.missing))
            },
-            getCacheDirName(useExternalMemory))
-          buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
+            getCacheDirName(xgbExecutionParam.useExternalMemory))
+          buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+            xgbExecutionParam.obj,
+            xgbExecutionParam.eval,
            prevBooster)
        }).cache()
    }
  }
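The SparkContext-taking `validateSparkSslConf(sc)` removed here survives as the no-argument method inside the factory. For users who hit the guard, the exception text above spells out the escape hatch; a hypothetical session setup that acknowledges the risk, assuming you accept unencrypted Rabit traffic between workers:

import org.apache.spark.sql.SparkSession

// Hypothetical opt-out, per the message thrown by validateSparkSslConf.
// spark.ssl.enabled stays on for Spark itself; xgboost.spark.ignoreSsl=true
// tells xgboost4j-spark to proceed (with a warning) anyway.
val spark = SparkSession.builder()
  .appName("xgboost4j-spark-app")            // hypothetical app name
  .config("spark.ssl.enabled", "true")
  .config("xgboost.spark.ignoreSsl", "true") // explicit, at-your-own-risk override
  .getOrCreate()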
@@ -428,31 +491,32 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
     (Booster, Map[String, Array[Float]]) = {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
-    val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval, skipCleanCheckpoint) =
-      parameterFetchAndValidation(params,
-        trainingData.sparkContext)
+    val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
+      buildXGBRuntimeParams
     val sc = trainingData.sparkContext
-    val checkpointManager = new CheckpointManager(sc, checkpointPath)
-    checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
-    val transformedTrainingData = composeInputData(trainingData,
-      params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers)
+    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
+      checkpointPath)
+    checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
+    val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
+      hasGroup, xgbExecParams.numWorkers)
     var prevBooster = checkpointManager.loadCheckpointAsBooster
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      val producedBooster = checkpointManager.getCheckpointRounds(
+        xgbExecParams.checkpointParam.checkpointInterval,
+        xgbExecParams.round).map {
         checkpointRound: Int =>
-          val tracker = startTracker(nWorkers, trackerConf)
+          val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
           try {
-            val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
-            val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers,
-              nWorkers)
+            val parallelismTracker = new SparkParallelismTracker(sc,
+              xgbExecParams.timeoutRequestWorkers,
+              xgbExecParams.numWorkers)
             val rabitEnv = tracker.getWorkerEnvs
             val boostersAndMetrics = if (hasGroup) {
-              trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv,
+              trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
                 checkpointRound, prevBooster, evalSetsMap)
             } else {
-              trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv,
+              trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
                 checkpointRound, prevBooster, evalSetsMap)
             }
             val sparkJobThread = new Thread() {
@@ -467,7 +531,7 @@ object XGBoost extends Serializable {
             logger.info(s"Rabit returns with exit code $trackerReturnVal")
             val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
               boostersAndMetrics, sparkJobThread)
-            if (checkpointRound < round) {
+            if (checkpointRound < xgbExecParams.round) {
               prevBooster = booster
               checkpointManager.updateCheckpoint(prevBooster)
             }
@@ -477,7 +541,7 @@ object XGBoost extends Serializable {
         }
       }.last
       // we should delete the checkpoint directory after a successful training
-      if (!skipCleanCheckpoint) {
+      if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
         checkpointManager.cleanPath()
       }
       producedBooster
@@ -488,8 +552,7 @@ object XGBoost extends Serializable {
       trainingData.sparkContext.stop()
       throw t
     } finally {
-      uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
-        transformedTrainingData)
+      uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData)
     }
   }
 
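All checkpoint plumbing in this loop now reads from `xgbExecParams.checkpointParam` instead of unpacked tuple members. The boundary sequence itself comes from `CheckpointManager.getCheckpointRounds`; the sketch below only models its behavior as an assumption, not the real implementation:

// Plausible model of getCheckpointRounds; an assumption, not the real code.
// With interval 10 and 25 total rounds, train in legs 10, 20, 25, saving the
// partial booster after every leg except the last (see the loop above).
def checkpointRounds(checkpointInterval: Int, round: Int): Seq[Int] =
  if (checkpointInterval > 0)
    ((checkpointInterval until round by checkpointInterval) :+ round).toSeq
  else Seq(round)

assert(checkpointRounds(10, 25) == Seq(10, 20, 25))
assert(checkpointRounds(0, 25) == Seq(25))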
@@ -673,11 +736,11 @@ private object Watches {
   }
 
   def buildWatches(
-      params: Map[String, Any],
+      xgbExecutionParams: XGBoostExecutionParams,
       labeledPoints: Iterator[XGBLabeledPoint],
       cacheDirName: Option[String]): Watches = {
-    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
-    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
+    val seed = xgbExecutionParams.xgbInputParams.seed
     val r = new Random(seed)
     val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
     val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
@@ -743,11 +806,11 @@ private object Watches {
   }
 
   def buildWatchesWithGroup(
-      params: Map[String, Any],
+      xgbExecutionParams: XGBoostExecutionParams,
       labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
       cacheDirName: Option[String]): Watches = {
-    val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
-    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
+    val seed = xgbExecutionParams.xgbInputParams.seed
     val r = new Random(seed)
     val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
     val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
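Both Watches builders now take the ratio and seed from `XGBoostExecutionInputParams` rather than re-parsing the raw map, so every partition derives the same split deterministically. An illustrative standalone version of that split, simplified relative to the real buffer-building code:

import scala.util.Random

// Illustrative train/test split driven by (trainTestRatio, seed), as the
// Watches builders now do; the real code also accumulates base margins etc.
def splitTrainTest[T](points: Iterator[T], trainTestRatio: Double, seed: Long): (Seq[T], Seq[T]) = {
  val r = new Random(seed)
  points.toSeq.partition(_ => r.nextDouble() <= trainTestRatio)
}

val (train, test) = splitTrainTest((1 to 100).iterator, 0.8, seed = 42L)
// roughly 80/20, and identical on every worker given the same seed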