[jvm-packages][refactor] refactor XGBoost.scala (spark) (#3904)
* add back `train` method but mark it as deprecated
* fix scalastyle errors
* wrap iterators
* remove unused code
* refactor
* fix typo
This commit is contained in:
parent
0cd326c1bc
commit
aa48b7e903
@@ -26,9 +26,9 @@ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
@@ -207,21 +207,14 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A tuple of the booster and the metrics used to build training summary
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
private[spark] def trainDistributed(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN,
|
||||
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
|
||||
validateSparkSslConf(trainingData.context)
|
||||
private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = {
|
||||
val nWorkers = params("num_workers").asInstanceOf[Int]
|
||||
val round = params("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean]
|
||||
val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
||||
val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||
val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
validateSparkSslConf(sparkContext)
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
@@ -245,11 +238,60 @@ object XGBoost extends Serializable {
|
||||
" an instance of Long.")
|
||||
}
|
||||
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
|
||||
(nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval)
|
||||
}
|
||||
|
||||
/**
 * Runs distributed training for non-ranking objectives: one booster is built
 * inside each partition of the repartitioned input data.
 *
 * @param trainingData    input labeled points to train on
 * @param params          XGBoost training parameters (worker count, missing value,
 *                        custom objective/eval are fetched from here)
 * @param rabitEnv        Rabit tracker environment handed to every worker
 * @param checkpointRound round number at which this training segment stops
 * @param prevBooster     booster restored from a checkpoint, or null on a fresh run
 * @return a cached RDD of (booster, metrics) pairs, one element per worker
 */
private def trainForNonRanking(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    rabitEnv: java.util.Map[String, String],
    checkpointRound: Int,
    prevBooster: Booster) = {
  // Only a subset of the validated parameters is needed here; the rest are
  // consumed by trainDistributed.
  val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
    parameterFetchAndValidation(params, trainingData.sparkContext)
  repartitionForTraining(trainingData, nWorkers)
    .mapPartitions { labeledPoints =>
      // Drop entries matching the configured missing value before building DMatrices.
      val cleanedPoints = removeMissingValues(labeledPoints, missing)
      val watches =
        Watches.buildWatches(params, cleanedPoints, getCacheDirName(useExternalMemory))
      buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
    }
    .cache() // keep results so collecting boosters does not re-run training
}
|
||||
|
||||
/**
 * Runs distributed training for ranking objectives. Unlike the non-ranking
 * path, the input must be repartitioned so that each query group stays intact
 * within a single partition, and the group-aware watch builder is used.
 *
 * @param trainingData    input labeled points (with group information) to train on
 * @param params          XGBoost training parameters (worker count, missing value,
 *                        custom objective/eval are fetched from here)
 * @param rabitEnv        Rabit tracker environment handed to every worker
 * @param checkpointRound round number at which this training segment stops
 * @param prevBooster     booster restored from a checkpoint, or null on a fresh run
 * @return a cached RDD of (booster, metrics) pairs, one element per worker
 */
private def trainForRanking(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    rabitEnv: java.util.Map[String, String],
    checkpointRound: Int,
    prevBooster: Booster) = {
  // Only a subset of the validated parameters is needed here; the rest are
  // consumed by trainDistributed.
  val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
    parameterFetchAndValidation(params, trainingData.sparkContext)
  repartitionForTrainingGroup(trainingData, nWorkers)
    .mapPartitions { labeledPointGroups =>
      // Filter missing values group-by-group so query groups are not broken up.
      val cleanedGroups = removeMissingValuesWithGroup(labeledPointGroups, missing)
      val watches =
        Watches.buildWatchesWithGroup(params, cleanedGroups, getCacheDirName(useExternalMemory))
      buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
    }
    .cache() // keep results so collecting boosters does not re-run training
}
|
||||
|
||||
/**
|
||||
* @return A tuple of the booster and the metrics used to build training summary
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
private[spark] def trainDistributed(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
|
||||
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
|
||||
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
|
||||
trainingData.sparkContext)
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
||||
checkpointManager.cleanUpHigherVersions(round)
|
||||
|
||||
checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
|
||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
|
||||
@@ -259,27 +301,12 @@ object XGBoost extends Serializable {
|
||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = hasGroup match {
|
||||
case true => {
|
||||
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPointGroups => {
|
||||
val watches = Watches.buildWatchesWithGroup(overriddenParams,
|
||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
}
|
||||
case false => {
|
||||
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||
partitionedData.mapPartitions(labeledPoints => {
|
||||
val watches = Watches.buildWatches(overriddenParams,
|
||||
removeMissingValues(labeledPoints, missing),
|
||||
getCacheDirName(useExternalMemory))
|
||||
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
|
||||
obj, eval, prevBooster)
|
||||
}).cache()
|
||||
}
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster)
|
||||
} else {
|
||||
trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
|
||||
prevBooster)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
|
||||
@@ -198,8 +198,7 @@ class XGBoostClassifier (
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing), hasGroup = false)
|
||||
hasGroup = false)
|
||||
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
|
||||
@@ -193,8 +193,7 @@ class XGBoostRegressor (
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing), hasGroup = group != lit(-1))
|
||||
hasGroup = group != lit(-1))
|
||||
val model = new XGBoostRegressionModel(uid, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
|
||||
@@ -78,10 +78,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
|
||||
hasGroup = false, missing = Float.NaN)
|
||||
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
hasGroup = false)
|
||||
assert(booster != null)
|
||||
}
|
||||
|
||||
@@ -270,9 +270,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
val (booster, metrics) = XGBoost.trainDistributed(
|
||||
trainingRDD,
|
||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap,
|
||||
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
|
||||
hasGroup = true, missing = Float.NaN)
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
|
||||
"missing" -> Float.NaN).toMap,
|
||||
hasGroup = true)
|
||||
|
||||
assert(booster != null)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user