[jvm-packages][refactor] refactor XGBoost.scala (spark) (#3904)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* wrap iterators

* remove unused code

* refactor

* fix typo
This commit is contained in:
Nan Zhu 2018-11-15 20:38:28 -08:00 committed by GitHub
parent 0cd326c1bc
commit aa48b7e903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 50 deletions

View File

@ -26,9 +26,9 @@ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
@ -207,21 +207,14 @@ object XGBoost extends Serializable {
}
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN,
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
validateSparkSslConf(trainingData.context)
private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = {
val nWorkers = params("num_workers").asInstanceOf[Int]
val round = params("num_round").asInstanceOf[Int]
val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean]
val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float]
validateSparkSslConf(sparkContext)
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")
@ -245,11 +238,60 @@ object XGBoost extends Serializable {
" an instance of Long.")
}
val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
(nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
checkpointPath, checkpointInterval)
}
/**
 * Builds one distributed booster per partition for the non-ranking (no group) case.
 *
 * @param trainingData labeled points to train on
 * @param params XGBoost parameter map; worker count, external-memory flag,
 *               custom obj/eval and the missing-value marker are read from it
 * @param rabitEnv Rabit tracker worker environment handed to each task
 * @param checkpointRound round count for this checkpoint segment
 * @param prevBooster booster restored from a previous checkpoint (may be null)
 * @return cached RDD of (Booster, metrics) pairs, one element per partition
 */
private def trainForNonRanking(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    rabitEnv: java.util.Map[String, String],
    checkpointRound: Int,
    prevBooster: Booster) = {
  // Re-run parameter fetch/validation; only a subset of the tuple is needed here.
  val fetched = parameterFetchAndValidation(params, trainingData.sparkContext)
  val workerCount = fetched._1
  val externalMemory = fetched._3
  val objFn = fetched._4
  val evalFn = fetched._5
  val missingValue = fetched._6
  repartitionForTraining(trainingData, workerCount).mapPartitions { labeledPointIter =>
    val watches = Watches.buildWatches(
      params,
      removeMissingValues(labeledPointIter, missingValue),
      getCacheDirName(externalMemory))
    buildDistributedBooster(
      watches, params, rabitEnv, checkpointRound, objFn, evalFn, prevBooster)
  }.cache()
}
/**
 * Builds one distributed booster per partition for the ranking case, where
 * labeled points carry group information and must be repartitioned by group.
 *
 * @param trainingData labeled points (grouped) to train on
 * @param params XGBoost parameter map; worker count, external-memory flag,
 *               custom obj/eval and the missing-value marker are read from it
 * @param rabitEnv Rabit tracker worker environment handed to each task
 * @param checkpointRound round count for this checkpoint segment
 * @param prevBooster booster restored from a previous checkpoint (may be null)
 * @return cached RDD of (Booster, metrics) pairs, one element per partition
 */
private def trainForRanking(
    trainingData: RDD[XGBLabeledPoint],
    params: Map[String, Any],
    rabitEnv: java.util.Map[String, String],
    checkpointRound: Int,
    prevBooster: Booster) = {
  // Re-run parameter fetch/validation; only a subset of the tuple is needed here.
  val fetched = parameterFetchAndValidation(params, trainingData.sparkContext)
  val workerCount = fetched._1
  val externalMemory = fetched._3
  val objFn = fetched._4
  val evalFn = fetched._5
  val missingValue = fetched._6
  repartitionForTrainingGroup(trainingData, workerCount).mapPartitions { groupIter =>
    val watches = Watches.buildWatchesWithGroup(
      params,
      removeMissingValuesWithGroup(groupIter, missingValue),
      getCacheDirName(externalMemory))
    buildDistributedBooster(
      watches, params, rabitEnv, checkpointRound, objFn, evalFn, prevBooster)
  }.cache()
}
/**
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
params: Map[String, Any],
hasGroup: Boolean = false): (Booster, Map[String, Array[Float]]) = {
val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
trainingData.sparkContext)
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, checkpointPath)
checkpointManager.cleanUpHigherVersions(round)
checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int])
var prevBooster = checkpointManager.loadCheckpointAsBooster
// Train for every ${savingRound} rounds and save the partially completed booster
checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
@ -259,27 +301,12 @@ object XGBoost extends Serializable {
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val rabitEnv = tracker.getWorkerEnvs
val boostersAndMetrics = hasGroup match {
case true => {
val partitionedData = repartitionForTrainingGroup(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(overriddenParams,
removeMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
}
case false => {
val partitionedData = repartitionForTraining(trainingData, nWorkers)
partitionedData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(overriddenParams,
removeMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, overriddenParams, rabitEnv, checkpointRound,
obj, eval, prevBooster)
}).cache()
}
val boostersAndMetrics = if (hasGroup) {
trainForRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
prevBooster)
} else {
trainForNonRanking(trainingData, overriddenParams, rabitEnv, checkpointRound,
prevBooster)
}
val sparkJobThread = new Thread() {
override def run() {

View File

@ -198,8 +198,7 @@ class XGBoostClassifier (
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing), hasGroup = false)
hasGroup = false)
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)

View File

@ -193,8 +193,7 @@ class XGBoostRegressor (
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing), hasGroup = group != lit(-1))
hasGroup = group != lit(-1))
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)

View File

@ -78,10 +78,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val (booster, metrics) = XGBoost.trainDistributed(
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
hasGroup = false, missing = Float.NaN)
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = false)
assert(booster != null)
}
@ -270,9 +270,10 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
val (booster, metrics) = XGBoost.trainDistributed(
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
round = 5, nWorkers = numWorkers, eval = null, obj = null, useExternalMemory = false,
hasGroup = true, missing = Float.NaN)
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap,
hasGroup = true)
assert(booster != null)
}